aboutsummaryrefslogtreecommitdiff
path: root/libgo/go/net
diff options
context:
space:
mode:
Diffstat (limited to 'libgo/go/net')
-rw-r--r--libgo/go/net/addrselect.go1
-rw-r--r--libgo/go/net/addrselect_test.go1
-rw-r--r--libgo/go/net/cgo_aix.go1
-rw-r--r--libgo/go/net/cgo_android.go1
-rw-r--r--libgo/go/net/cgo_bsd.go3
-rw-r--r--libgo/go/net/cgo_linux.go1
-rw-r--r--libgo/go/net/cgo_netbsd.go1
-rw-r--r--libgo/go/net/cgo_openbsd.go1
-rw-r--r--libgo/go/net/cgo_resnew.go3
-rw-r--r--libgo/go/net/cgo_resold.go3
-rw-r--r--libgo/go/net/cgo_socknew.go3
-rw-r--r--libgo/go/net/cgo_sockold.go3
-rw-r--r--libgo/go/net/cgo_solaris.go1
-rw-r--r--libgo/go/net/cgo_stub.go1
-rw-r--r--libgo/go/net/cgo_unix.go5
-rw-r--r--libgo/go/net/cgo_unix_test.go3
-rw-r--r--libgo/go/net/cgo_windows.go1
-rw-r--r--libgo/go/net/conf.go4
-rw-r--r--libgo/go/net/conf_netcgo.go1
-rw-r--r--libgo/go/net/conf_test.go1
-rw-r--r--libgo/go/net/conn_test.go8
-rw-r--r--libgo/go/net/dial_test.go73
-rw-r--r--libgo/go/net/dial_unix_test.go13
-rw-r--r--libgo/go/net/dnsclient.go21
-rw-r--r--libgo/go/net/dnsclient_unix.go1
-rw-r--r--libgo/go/net/dnsclient_unix_test.go42
-rw-r--r--libgo/go/net/dnsconfig_unix.go1
-rw-r--r--libgo/go/net/dnsconfig_unix_test.go1
-rw-r--r--libgo/go/net/dnsname_test.go1
-rw-r--r--libgo/go/net/error_plan9_test.go4
-rw-r--r--libgo/go/net/error_posix.go1
-rw-r--r--libgo/go/net/error_posix_test.go1
-rw-r--r--libgo/go/net/error_test.go16
-rw-r--r--libgo/go/net/error_unix.go1
-rw-r--r--libgo/go/net/error_unix_test.go6
-rw-r--r--libgo/go/net/error_windows_test.go12
-rw-r--r--libgo/go/net/example_test.go206
-rw-r--r--libgo/go/net/external_test.go1
-rw-r--r--libgo/go/net/fcntl_libc_test.go1
-rw-r--r--libgo/go/net/fcntl_syscall_test.go1
-rw-r--r--libgo/go/net/fd_posix.go48
-rw-r--r--libgo/go/net/fd_unix.go9
-rw-r--r--libgo/go/net/file_stub.go1
-rw-r--r--libgo/go/net/file_test.go36
-rw-r--r--libgo/go/net/file_unix.go1
-rw-r--r--libgo/go/net/hook_unix.go1
-rw-r--r--libgo/go/net/hosts.go15
-rw-r--r--libgo/go/net/hosts_test.go4
-rw-r--r--libgo/go/net/http/cgi/child.go4
-rw-r--r--libgo/go/net/http/cgi/host.go7
-rw-r--r--libgo/go/net/http/cgi/host_test.go10
-rw-r--r--libgo/go/net/http/cgi/posix_test.go1
-rw-r--r--libgo/go/net/http/client.go1
-rw-r--r--libgo/go/net/http/client_test.go188
-rw-r--r--libgo/go/net/http/clientserver_test.go44
-rw-r--r--libgo/go/net/http/cookie.go60
-rw-r--r--libgo/go/net/http/cookie_test.go27
-rw-r--r--libgo/go/net/http/export_test.go7
-rw-r--r--libgo/go/net/http/fs.go22
-rw-r--r--libgo/go/net/http/fs_test.go25
-rw-r--r--libgo/go/net/http/h2_bundle.go1832
-rw-r--r--libgo/go/net/http/header.go11
-rw-r--r--libgo/go/net/http/header_test.go13
-rw-r--r--libgo/go/net/http/httptrace/trace.go2
-rw-r--r--libgo/go/net/http/httputil/dump.go2
-rw-r--r--libgo/go/net/http/httputil/dump_test.go2
-rw-r--r--libgo/go/net/http/httputil/reverseproxy.go5
-rw-r--r--libgo/go/net/http/httputil/reverseproxy_test.go20
-rw-r--r--libgo/go/net/http/internal/chunked.go16
-rw-r--r--libgo/go/net/http/internal/chunked_test.go28
-rw-r--r--libgo/go/net/http/internal/testcert/testcert.go69
-rw-r--r--libgo/go/net/http/main_test.go9
-rw-r--r--libgo/go/net/http/omithttp2.go7
-rw-r--r--libgo/go/net/http/pprof/pprof.go2
-rw-r--r--libgo/go/net/http/pprof/pprof_test.go5
-rw-r--r--libgo/go/net/http/request.go44
-rw-r--r--libgo/go/net/http/request_test.go8
-rw-r--r--libgo/go/net/http/requestwrite_test.go2
-rw-r--r--libgo/go/net/http/response.go15
-rw-r--r--libgo/go/net/http/response_test.go16
-rw-r--r--libgo/go/net/http/roundtrip.go1
-rw-r--r--libgo/go/net/http/roundtrip_js.go43
-rw-r--r--libgo/go/net/http/serve_test.go219
-rw-r--r--libgo/go/net/http/server.go101
-rw-r--r--libgo/go/net/http/server_test.go53
-rw-r--r--libgo/go/net/http/transfer.go12
-rw-r--r--libgo/go/net/http/transport.go20
-rw-r--r--libgo/go/net/http/transport_default_js.go17
-rw-r--r--libgo/go/net/http/transport_default_other.go17
-rw-r--r--libgo/go/net/http/transport_test.go39
-rw-r--r--libgo/go/net/http/triv.go1
-rw-r--r--libgo/go/net/interface_aix.go2
-rw-r--r--libgo/go/net/interface_bsd.go1
-rw-r--r--libgo/go/net/interface_bsd_test.go1
-rw-r--r--libgo/go/net/interface_bsdvar.go1
-rw-r--r--libgo/go/net/interface_freebsd.go11
-rw-r--r--libgo/go/net/interface_stub.go1
-rw-r--r--libgo/go/net/interface_test.go1
-rw-r--r--libgo/go/net/interface_unix_test.go1
-rw-r--r--libgo/go/net/internal/socktest/main_test.go1
-rw-r--r--libgo/go/net/internal/socktest/main_unix_test.go1
-rw-r--r--libgo/go/net/internal/socktest/switch_posix.go1
-rw-r--r--libgo/go/net/internal/socktest/switch_stub.go1
-rw-r--r--libgo/go/net/internal/socktest/switch_unix.go1
-rw-r--r--libgo/go/net/internal/socktest/sys_cloexec.go1
-rw-r--r--libgo/go/net/internal/socktest/sys_unix.go1
-rw-r--r--libgo/go/net/ip.go5
-rw-r--r--libgo/go/net/ip_test.go4
-rw-r--r--libgo/go/net/iprawsock_posix.go1
-rw-r--r--libgo/go/net/iprawsock_test.go1
-rw-r--r--libgo/go/net/ipsock_posix.go107
-rw-r--r--libgo/go/net/listen_test.go32
-rw-r--r--libgo/go/net/lookup.go82
-rw-r--r--libgo/go/net/lookup_fake.go1
-rw-r--r--libgo/go/net/lookup_plan9.go10
-rw-r--r--libgo/go/net/lookup_test.go97
-rw-r--r--libgo/go/net/lookup_unix.go1
-rw-r--r--libgo/go/net/lookup_windows.go14
-rw-r--r--libgo/go/net/lookup_windows_test.go12
-rw-r--r--libgo/go/net/mail/message.go4
-rw-r--r--libgo/go/net/main_cloexec_test.go1
-rw-r--r--libgo/go/net/main_conf_test.go1
-rw-r--r--libgo/go/net/main_noconf_test.go1
-rw-r--r--libgo/go/net/main_posix_test.go5
-rw-r--r--libgo/go/net/main_test.go8
-rw-r--r--libgo/go/net/main_unix_test.go1
-rw-r--r--libgo/go/net/mockserver_test.go168
-rw-r--r--libgo/go/net/net.go14
-rw-r--r--libgo/go/net/net_fake.go33
-rw-r--r--libgo/go/net/net_test.go79
-rw-r--r--libgo/go/net/netip/export_test.go30
-rw-r--r--libgo/go/net/netip/fuzz_test.go353
-rw-r--r--libgo/go/net/netip/inlining_test.go110
-rw-r--r--libgo/go/net/netip/leaf_alts.go54
-rw-r--r--libgo/go/net/netip/netip.go1498
-rw-r--r--libgo/go/net/netip/netip_pkg_test.go359
-rw-r--r--libgo/go/net/netip/netip_test.go1974
-rw-r--r--libgo/go/net/netip/slow_test.go190
-rw-r--r--libgo/go/net/netip/uint128.go92
-rw-r--r--libgo/go/net/netip/uint128_test.go89
-rw-r--r--libgo/go/net/nss.go1
-rw-r--r--libgo/go/net/nss_test.go1
-rw-r--r--libgo/go/net/packetconn_test.go41
-rw-r--r--libgo/go/net/parse.go33
-rw-r--r--libgo/go/net/parse_test.go27
-rw-r--r--libgo/go/net/platform_test.go12
-rw-r--r--libgo/go/net/port_unix.go1
-rw-r--r--libgo/go/net/protoconn_test.go10
-rw-r--r--libgo/go/net/rawconn_stub_test.go1
-rw-r--r--libgo/go/net/rawconn_test.go16
-rw-r--r--libgo/go/net/rawconn_unix_test.go1
-rw-r--r--libgo/go/net/rpc/client.go22
-rw-r--r--libgo/go/net/rpc/client_test.go8
-rw-r--r--libgo/go/net/rpc/debug.go2
-rw-r--r--libgo/go/net/rpc/jsonrpc/all_test.go6
-rw-r--r--libgo/go/net/rpc/jsonrpc/client.go12
-rw-r--r--libgo/go/net/rpc/jsonrpc/server.go10
-rw-r--r--libgo/go/net/rpc/server.go52
-rw-r--r--libgo/go/net/rpc/server_test.go8
-rw-r--r--libgo/go/net/sendfile_stub.go3
-rw-r--r--libgo/go/net/sendfile_test.go26
-rw-r--r--libgo/go/net/sendfile_unix_alt.go3
-rw-r--r--libgo/go/net/server_test.go69
-rw-r--r--libgo/go/net/smtp/smtp.go10
-rw-r--r--libgo/go/net/smtp/smtp_test.go2
-rw-r--r--libgo/go/net/sock_bsd.go1
-rw-r--r--libgo/go/net/sock_cloexec.go1
-rw-r--r--libgo/go/net/sock_posix.go1
-rw-r--r--libgo/go/net/sock_stub.go1
-rw-r--r--libgo/go/net/sockaddr_posix.go1
-rw-r--r--libgo/go/net/sockopt_bsd.go1
-rw-r--r--libgo/go/net/sockopt_posix.go1
-rw-r--r--libgo/go/net/sockopt_stub.go1
-rw-r--r--libgo/go/net/sockoptip_bsdvar.go1
-rw-r--r--libgo/go/net/sockoptip_posix.go1
-rw-r--r--libgo/go/net/sockoptip_stub.go7
-rw-r--r--libgo/go/net/splice_stub.go1
-rw-r--r--libgo/go/net/splice_test.go74
-rw-r--r--libgo/go/net/sys_cloexec.go1
-rw-r--r--libgo/go/net/tcpsock.go26
-rw-r--r--libgo/go/net/tcpsock_posix.go1
-rw-r--r--libgo/go/net/tcpsock_test.go26
-rw-r--r--libgo/go/net/tcpsock_unix_test.go6
-rw-r--r--libgo/go/net/tcpsockopt_posix.go1
-rw-r--r--libgo/go/net/tcpsockopt_stub.go1
-rw-r--r--libgo/go/net/tcpsockopt_unix.go1
-rw-r--r--libgo/go/net/textproto/reader.go28
-rw-r--r--libgo/go/net/textproto/textproto.go2
-rw-r--r--libgo/go/net/textproto/writer.go2
-rw-r--r--libgo/go/net/timeout_test.go469
-rw-r--r--libgo/go/net/udpsock.go81
-rw-r--r--libgo/go/net/udpsock_plan9.go32
-rw-r--r--libgo/go/net/udpsock_posix.go145
-rw-r--r--libgo/go/net/udpsock_test.go168
-rw-r--r--libgo/go/net/unixsock_posix.go1
-rw-r--r--libgo/go/net/unixsock_readmsg_cloexec.go1
-rw-r--r--libgo/go/net/unixsock_readmsg_cmsg_cloexec.go1
-rw-r--r--libgo/go/net/unixsock_readmsg_other.go1
-rw-r--r--libgo/go/net/unixsock_readmsg_test.go1
-rw-r--r--libgo/go/net/unixsock_test.go32
-rw-r--r--libgo/go/net/unixsock_windows_test.go10
-rw-r--r--libgo/go/net/url/url.go60
-rw-r--r--libgo/go/net/url/url_test.go20
-rw-r--r--libgo/go/net/write_unix_test.go1
-rw-r--r--libgo/go/net/writev_test.go6
-rw-r--r--libgo/go/net/writev_unix.go1
206 files changed, 8428 insertions, 2215 deletions
diff --git a/libgo/go/net/addrselect.go b/libgo/go/net/addrselect.go
index 4603c55..e910181 100644
--- a/libgo/go/net/addrselect.go
+++ b/libgo/go/net/addrselect.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build aix || darwin || dragonfly || freebsd || hurd || linux || netbsd || openbsd || solaris
-// +build aix darwin dragonfly freebsd hurd linux netbsd openbsd solaris
// Minimal RFC 6724 address selection.
diff --git a/libgo/go/net/addrselect_test.go b/libgo/go/net/addrselect_test.go
index 18784fe..a958e2e 100644
--- a/libgo/go/net/addrselect_test.go
+++ b/libgo/go/net/addrselect_test.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build darwin || dragonfly || freebsd || hurd || linux || netbsd || openbsd || solaris
-// +build darwin dragonfly freebsd hurd linux netbsd openbsd solaris
package net
diff --git a/libgo/go/net/cgo_aix.go b/libgo/go/net/cgo_aix.go
index 577649f..6ee0f09 100644
--- a/libgo/go/net/cgo_aix.go
+++ b/libgo/go/net/cgo_aix.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build cgo && !netgo
-// +build cgo,!netgo
package net
diff --git a/libgo/go/net/cgo_android.go b/libgo/go/net/cgo_android.go
index 4b1a2e3..5ab8b5f 100644
--- a/libgo/go/net/cgo_android.go
+++ b/libgo/go/net/cgo_android.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build cgo && !netgo
-// +build cgo,!netgo
package net
diff --git a/libgo/go/net/cgo_bsd.go b/libgo/go/net/cgo_bsd.go
index 1268c89..830e589 100644
--- a/libgo/go/net/cgo_bsd.go
+++ b/libgo/go/net/cgo_bsd.go
@@ -3,9 +3,6 @@
// license that can be found in the LICENSE file.
//go:build cgo && !netgo && (darwin || dragonfly || freebsd)
-// +build cgo
-// +build !netgo
-// +build darwin dragonfly freebsd
package net
diff --git a/libgo/go/net/cgo_linux.go b/libgo/go/net/cgo_linux.go
index 4b45dad..5d67699 100644
--- a/libgo/go/net/cgo_linux.go
+++ b/libgo/go/net/cgo_linux.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build !android && cgo && !netgo
-// +build !android,cgo,!netgo
package net
diff --git a/libgo/go/net/cgo_netbsd.go b/libgo/go/net/cgo_netbsd.go
index e23899d..4778811 100644
--- a/libgo/go/net/cgo_netbsd.go
+++ b/libgo/go/net/cgo_netbsd.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build cgo && !netgo
-// +build cgo,!netgo
package net
diff --git a/libgo/go/net/cgo_openbsd.go b/libgo/go/net/cgo_openbsd.go
index 3714793..03392e8 100644
--- a/libgo/go/net/cgo_openbsd.go
+++ b/libgo/go/net/cgo_openbsd.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build cgo && !netgo
-// +build cgo,!netgo
package net
diff --git a/libgo/go/net/cgo_resnew.go b/libgo/go/net/cgo_resnew.go
index 6611fd7..985faee 100644
--- a/libgo/go/net/cgo_resnew.go
+++ b/libgo/go/net/cgo_resnew.go
@@ -3,9 +3,6 @@
// license that can be found in the LICENSE file.
//go:build cgo && !netgo && (aix || darwin || hurd || (linux && !android) || netbsd || solaris)
-// +build cgo
-// +build !netgo
-// +build aix darwin hurd linux,!android netbsd solaris
package net
diff --git a/libgo/go/net/cgo_resold.go b/libgo/go/net/cgo_resold.go
index 33f664c..b65e020 100644
--- a/libgo/go/net/cgo_resold.go
+++ b/libgo/go/net/cgo_resold.go
@@ -3,9 +3,6 @@
// license that can be found in the LICENSE file.
//go:build cgo && !netgo && (android || freebsd || dragonfly || openbsd)
-// +build cgo
-// +build !netgo
-// +build android freebsd dragonfly openbsd
package net
diff --git a/libgo/go/net/cgo_socknew.go b/libgo/go/net/cgo_socknew.go
index 84b40c9..2c3ab63 100644
--- a/libgo/go/net/cgo_socknew.go
+++ b/libgo/go/net/cgo_socknew.go
@@ -3,9 +3,6 @@
// license that can be found in the LICENSE file.
//go:build cgo && !netgo && (android || linux || solaris)
-// +build cgo
-// +build !netgo
-// +build android linux solaris
package net
diff --git a/libgo/go/net/cgo_sockold.go b/libgo/go/net/cgo_sockold.go
index 703b41b..461ecb4 100644
--- a/libgo/go/net/cgo_sockold.go
+++ b/libgo/go/net/cgo_sockold.go
@@ -3,9 +3,6 @@
// license that can be found in the LICENSE file.
//go:build cgo && !netgo && (aix || darwin || dragonfly || freebsd || hurd || netbsd || openbsd)
-// +build cgo
-// +build !netgo
-// +build aix darwin dragonfly freebsd hurd netbsd openbsd
package net
diff --git a/libgo/go/net/cgo_solaris.go b/libgo/go/net/cgo_solaris.go
index 95d5db5..95a23cf 100644
--- a/libgo/go/net/cgo_solaris.go
+++ b/libgo/go/net/cgo_solaris.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build cgo && !netgo
-// +build cgo,!netgo
package net
diff --git a/libgo/go/net/cgo_stub.go b/libgo/go/net/cgo_stub.go
index 039e4be..cc84ca4 100644
--- a/libgo/go/net/cgo_stub.go
+++ b/libgo/go/net/cgo_stub.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build !cgo || netgo
-// +build !cgo netgo
package net
diff --git a/libgo/go/net/cgo_unix.go b/libgo/go/net/cgo_unix.go
index 462bf12..26b3da3 100644
--- a/libgo/go/net/cgo_unix.go
+++ b/libgo/go/net/cgo_unix.go
@@ -3,9 +3,6 @@
// license that can be found in the LICENSE file.
//go:build cgo && !netgo && (aix || darwin || dragonfly || freebsd || hurd || linux || netbsd || openbsd || solaris)
-// +build cgo
-// +build !netgo
-// +build aix darwin dragonfly freebsd hurd linux netbsd openbsd solaris
package net
@@ -352,7 +349,7 @@ func cgoLookupAddrPTR(addr string, sa *syscall.RawSockaddr, salen syscall.Sockle
break
}
}
- return []string{absDomainName(b)}, nil
+ return []string{absDomainName(string(b))}, nil
}
func cgoReverseLookup(result chan<- reverseLookupResult, addr string, sa *syscall.RawSockaddr, salen syscall.Socklen_t) {
diff --git a/libgo/go/net/cgo_unix_test.go b/libgo/go/net/cgo_unix_test.go
index 98b3b4a..5264fcd 100644
--- a/libgo/go/net/cgo_unix_test.go
+++ b/libgo/go/net/cgo_unix_test.go
@@ -3,9 +3,6 @@
// license that can be found in the LICENSE file.
//go:build cgo && !netgo && (aix || darwin || dragonfly || freebsd || hurd || linux || netbsd || openbsd || solaris)
-// +build cgo
-// +build !netgo
-// +build aix darwin dragonfly freebsd hurd linux netbsd openbsd solaris
package net
diff --git a/libgo/go/net/cgo_windows.go b/libgo/go/net/cgo_windows.go
index 1fd1f297..6bb6cbb 100644
--- a/libgo/go/net/cgo_windows.go
+++ b/libgo/go/net/cgo_windows.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build cgo && !netgo
-// +build cgo,!netgo
package net
diff --git a/libgo/go/net/conf.go b/libgo/go/net/conf.go
index fe7ebf1..6edecaf 100644
--- a/libgo/go/net/conf.go
+++ b/libgo/go/net/conf.go
@@ -3,12 +3,12 @@
// license that can be found in the LICENSE file.
//go:build aix || darwin || dragonfly || freebsd || hurd || linux || netbsd || openbsd || solaris
-// +build aix darwin dragonfly freebsd hurd linux netbsd openbsd solaris
package net
import (
"internal/bytealg"
+ "internal/godebug"
"os"
"runtime"
"sync"
@@ -287,7 +287,7 @@ func (c *conf) hostLookupOrder(r *Resolver, hostname string) (ret hostLookupOrde
// cgo+2 // same, but debug level 2
// etc.
func goDebugNetDNS() (dnsMode string, debugLevel int) {
- goDebug := goDebugString("netdns")
+ goDebug := godebug.Get("netdns")
parsePart := func(s string) {
if s == "" {
return
diff --git a/libgo/go/net/conf_netcgo.go b/libgo/go/net/conf_netcgo.go
index c705152..3447a87 100644
--- a/libgo/go/net/conf_netcgo.go
+++ b/libgo/go/net/conf_netcgo.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build netcgo
-// +build netcgo
package net
diff --git a/libgo/go/net/conf_test.go b/libgo/go/net/conf_test.go
index f5e4d86..8c2d3ce 100644
--- a/libgo/go/net/conf_test.go
+++ b/libgo/go/net/conf_test.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build darwin || dragonfly || freebsd || hurd || linux || netbsd || openbsd || solaris
-// +build darwin dragonfly freebsd hurd linux netbsd openbsd solaris
package net
diff --git a/libgo/go/net/conn_test.go b/libgo/go/net/conn_test.go
index 45e271c..d168dda 100644
--- a/libgo/go/net/conn_test.go
+++ b/libgo/go/net/conn_test.go
@@ -6,7 +6,6 @@
// tag.
//go:build !js
-// +build !js
package net
@@ -18,7 +17,7 @@ import (
// someTimeout is used just to test that net.Conn implementations
// don't explode when their SetFooDeadline methods are called.
// It isn't actually used for testing timeouts.
-const someTimeout = 10 * time.Second
+const someTimeout = 1 * time.Hour
func TestConnAndListener(t *testing.T) {
for i, network := range []string{"tcp", "unix", "unixpacket"} {
@@ -27,10 +26,7 @@ func TestConnAndListener(t *testing.T) {
continue
}
- ls, err := newLocalServer(network)
- if err != nil {
- t.Fatal(err)
- }
+ ls := newLocalServer(t, network)
defer ls.teardown()
ch := make(chan error, 1)
handler := func(ls *localServer, ln Listener) { ls.transponder(ln, ch) }
diff --git a/libgo/go/net/dial_test.go b/libgo/go/net/dial_test.go
index 723038c..b9aead0 100644
--- a/libgo/go/net/dial_test.go
+++ b/libgo/go/net/dial_test.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build !js
-// +build !js
package net
@@ -60,10 +59,7 @@ func TestProhibitionaryDialArg(t *testing.T) {
}
func TestDialLocal(t *testing.T) {
- ln, err := newLocalListener("tcp")
- if err != nil {
- t.Fatal(err)
- }
+ ln := newLocalListener(t, "tcp")
defer ln.Close()
_, port, err := SplitHostPort(ln.Addr().String())
if err != nil {
@@ -433,14 +429,15 @@ func TestDialParallelSpuriousConnection(t *testing.T) {
readDeadline = time.Now().Add(5 * time.Second)
}
- var wg sync.WaitGroup
- wg.Add(2)
+ var closed sync.WaitGroup
+ closed.Add(2)
handler := func(dss *dualStackServer, ln Listener) {
// Accept one connection per address.
c, err := ln.Accept()
if err != nil {
t.Fatal(err)
}
+
// The client should close itself, without sending data.
c.SetReadDeadline(readDeadline)
var b [1]byte
@@ -448,7 +445,7 @@ func TestDialParallelSpuriousConnection(t *testing.T) {
t.Errorf("got %v; want %v", err, io.EOF)
}
c.Close()
- wg.Done()
+ closed.Done()
}
dss, err := newDualStackServer()
if err != nil {
@@ -461,12 +458,16 @@ func TestDialParallelSpuriousConnection(t *testing.T) {
const fallbackDelay = 100 * time.Millisecond
+ var dialing sync.WaitGroup
+ dialing.Add(2)
origTestHookDialTCP := testHookDialTCP
defer func() { testHookDialTCP = origTestHookDialTCP }()
testHookDialTCP = func(ctx context.Context, net string, laddr, raddr *TCPAddr) (*TCPConn, error) {
- // Sleep long enough for Happy Eyeballs to kick in, and inhibit cancellation.
+ // Wait until Happy Eyeballs kicks in and both connections are dialing,
+ // and inhibit cancellation.
// This forces dialParallel to juggle two successful connections.
- time.Sleep(fallbackDelay * 2)
+ dialing.Done()
+ dialing.Wait()
// Now ignore the provided context (which will be canceled) and use a
// different one to make sure this completes with a valid connection,
@@ -500,7 +501,7 @@ func TestDialParallelSpuriousConnection(t *testing.T) {
c.Close()
// The server should've seen both connections.
- wg.Wait()
+ closed.Wait()
}
func TestDialerPartialDeadline(t *testing.T) {
@@ -538,6 +539,9 @@ func TestDialerPartialDeadline(t *testing.T) {
}
}
+// isEADDRINUSE reports whether err is syscall.EADDRINUSE.
+var isEADDRINUSE = func(err error) bool { return false }
+
func TestDialerLocalAddr(t *testing.T) {
if !supportsIPv4() || !supportsIPv6() {
t.Skip("both IPv4 and IPv6 are required")
@@ -593,7 +597,9 @@ func TestDialerLocalAddr(t *testing.T) {
{"tcp", "::1", &UnixAddr{}, &AddrError{Err: "some error"}},
}
+ issue34264Index := -1
if supportsIPv4map() {
+ issue34264Index = len(tests)
tests = append(tests, test{
"tcp", "127.0.0.1", &TCPAddr{IP: ParseIP("::")}, nil,
})
@@ -615,20 +621,16 @@ func TestDialerLocalAddr(t *testing.T) {
c.Close()
}
}
- var err error
var lss [2]*localServer
for i, network := range []string{"tcp4", "tcp6"} {
- lss[i], err = newLocalServer(network)
- if err != nil {
- t.Fatal(err)
- }
+ lss[i] = newLocalServer(t, network)
defer lss[i].teardown()
if err := lss[i].buildup(handler); err != nil {
t.Fatal(err)
}
}
- for _, tt := range tests {
+ for i, tt := range tests {
d := &Dialer{LocalAddr: tt.laddr}
var addr string
ip := ParseIP(tt.raddr)
@@ -640,7 +642,15 @@ func TestDialerLocalAddr(t *testing.T) {
}
c, err := d.Dial(tt.network, addr)
if err == nil && tt.error != nil || err != nil && tt.error == nil {
- t.Errorf("%s %v->%s: got %v; want %v", tt.network, tt.laddr, tt.raddr, err, tt.error)
+ if i == issue34264Index && runtime.GOOS == "freebsd" && isEADDRINUSE(err) {
+ // https://golang.org/issue/34264: FreeBSD through at least version 12.2
+ // has been observed to fail with EADDRINUSE when dialing from an IPv6
+ // local address to an IPv4 remote address.
+ t.Logf("%s %v->%s: got %v; want %v", tt.network, tt.laddr, tt.raddr, err, tt.error)
+ t.Logf("(spurious EADDRINUSE ignored on freebsd: see https://golang.org/issue/34264)")
+ } else {
+ t.Errorf("%s %v->%s: got %v; want %v", tt.network, tt.laddr, tt.raddr, err, tt.error)
+ }
}
if err != nil {
if perr := parseDialError(err); perr != nil {
@@ -713,10 +723,7 @@ func TestDialerKeepAlive(t *testing.T) {
c.Close()
}
}
- ls, err := newLocalServer("tcp")
- if err != nil {
- t.Fatal(err)
- }
+ ls := newLocalServer(t, "tcp")
defer ls.teardown()
if err := ls.buildup(handler); err != nil {
t.Fatal(err)
@@ -814,10 +821,7 @@ func TestCancelAfterDial(t *testing.T) {
t.Skip("avoiding time.Sleep")
}
- ln, err := newLocalListener("tcp")
- if err != nil {
- t.Fatal(err)
- }
+ ln := newLocalListener(t, "tcp")
var wg sync.WaitGroup
wg.Add(1)
@@ -920,11 +924,7 @@ func TestDialerControl(t *testing.T) {
if !testableNetwork(network) {
continue
}
- ln, err := newLocalListener(network)
- if err != nil {
- t.Error(err)
- continue
- }
+ ln := newLocalListener(t, network)
defer ln.Close()
d := Dialer{Control: controlOnConnSetup}
c, err := d.Dial(network, ln.Addr().String())
@@ -940,11 +940,7 @@ func TestDialerControl(t *testing.T) {
if !testableNetwork(network) {
continue
}
- c1, err := newLocalPacketListener(network)
- if err != nil {
- t.Error(err)
- continue
- }
+ c1 := newLocalPacketListener(t, network)
if network == "unixgram" {
defer os.Remove(c1.LocalAddr().String())
}
@@ -980,10 +976,7 @@ func (contextWithNonZeroDeadline) Deadline() (time.Time, bool) {
}
func TestDialWithNonZeroDeadline(t *testing.T) {
- ln, err := newLocalListener("tcp")
- if err != nil {
- t.Fatal(err)
- }
+ ln := newLocalListener(t, "tcp")
defer ln.Close()
_, port, err := SplitHostPort(ln.Addr().String())
if err != nil {
diff --git a/libgo/go/net/dial_unix_test.go b/libgo/go/net/dial_unix_test.go
index 4b9bc27..45d032c 100644
--- a/libgo/go/net/dial_unix_test.go
+++ b/libgo/go/net/dial_unix_test.go
@@ -3,17 +3,23 @@
// license that can be found in the LICENSE file.
//go:build aix || darwin || dragonfly || freebsd || hurd || linux || netbsd || openbsd || solaris
-// +build aix darwin dragonfly freebsd hurd linux netbsd openbsd solaris
package net
import (
"context"
+ "errors"
"syscall"
"testing"
"time"
)
+func init() {
+ isEADDRINUSE = func(err error) bool {
+ return errors.Is(err, syscall.EADDRINUSE)
+ }
+}
+
// Issue 16523
func TestDialContextCancelRace(t *testing.T) {
oldConnectFunc := connectFunc
@@ -25,10 +31,7 @@ func TestDialContextCancelRace(t *testing.T) {
testHookCanceledDial = oldTestHookCanceledDial
}()
- ln, err := newLocalListener("tcp")
- if err != nil {
- t.Fatal(err)
- }
+ ln := newLocalListener(t, "tcp")
listenerDone := make(chan struct{})
go func() {
defer close(listenerDone)
diff --git a/libgo/go/net/dnsclient.go b/libgo/go/net/dnsclient.go
index 1bbe396..a779c37 100644
--- a/libgo/go/net/dnsclient.go
+++ b/libgo/go/net/dnsclient.go
@@ -5,6 +5,7 @@
package net
import (
+ "internal/bytealg"
"internal/itoa"
"sort"
@@ -75,6 +76,11 @@ func equalASCIIName(x, y dnsmessage.Name) bool {
// (currently restricted to hostname-compatible "preferred name" LDH labels and
// SRV-like "underscore labels"; see golang.org/issue/12421).
func isDomainName(s string) bool {
+ // The root domain name is valid. See golang.org/issue/45715.
+ if s == "." {
+ return true
+ }
+
// See RFC 1035, RFC 3696.
// Presentation format has dots before every label except the first, and the
// terminal empty label is optional here because we assume fully-qualified
@@ -136,18 +142,11 @@ func isDomainName(s string) bool {
// It's hard to tell so we settle on the heuristic that names without dots
// (like "localhost" or "myhost") do not get trailing dots, but any other
// names do.
-func absDomainName(b []byte) string {
- hasDots := false
- for _, x := range b {
- if x == '.' {
- hasDots = true
- break
- }
- }
- if hasDots && b[len(b)-1] != '.' {
- b = append(b, '.')
+func absDomainName(s string) string {
+ if bytealg.IndexByteString(s, '.') != -1 && s[len(s)-1] != '.' {
+ s += "."
}
- return string(b)
+ return s
}
// An SRV represents a single DNS SRV record.
diff --git a/libgo/go/net/dnsclient_unix.go b/libgo/go/net/dnsclient_unix.go
index a326319..3278791e 100644
--- a/libgo/go/net/dnsclient_unix.go
+++ b/libgo/go/net/dnsclient_unix.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build aix || darwin || dragonfly || freebsd || hurd || linux || netbsd || openbsd || solaris
-// +build aix darwin dragonfly freebsd hurd linux netbsd openbsd solaris
// DNS client: see RFC 1035.
// Has to be linked into package net for Dial.
diff --git a/libgo/go/net/dnsclient_unix_test.go b/libgo/go/net/dnsclient_unix_test.go
index ce1a4f3..e34c0a5 100644
--- a/libgo/go/net/dnsclient_unix_test.go
+++ b/libgo/go/net/dnsclient_unix_test.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build aix || darwin || dragonfly || freebsd || hurd || linux || netbsd || openbsd || solaris
-// +build aix darwin dragonfly freebsd hurd linux netbsd openbsd solaris
package net
@@ -2121,3 +2120,44 @@ func TestNullMX(t *testing.T) {
t.Errorf("records = [%v]; want [%v]", strings.Join(records, " "), want[0])
}
}
+
+func TestRootNS(t *testing.T) {
+ // See https://golang.org/issue/45715.
+ fake := fakeDNSServer{
+ rh: func(n, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
+ r := dnsmessage.Message{
+ Header: dnsmessage.Header{
+ ID: q.Header.ID,
+ Response: true,
+ RCode: dnsmessage.RCodeSuccess,
+ },
+ Questions: q.Questions,
+ Answers: []dnsmessage.Resource{
+ {
+ Header: dnsmessage.ResourceHeader{
+ Name: q.Questions[0].Name,
+ Type: dnsmessage.TypeNS,
+ Class: dnsmessage.ClassINET,
+ },
+ Body: &dnsmessage.NSResource{
+ NS: dnsmessage.MustNewName("i.root-servers.net."),
+ },
+ },
+ },
+ }
+ return r, nil
+ },
+ }
+ r := Resolver{PreferGo: true, Dial: fake.DialContext}
+ rrset, err := r.LookupNS(context.Background(), ".")
+ if err != nil {
+ t.Fatalf("LookupNS: %v", err)
+ }
+ if want := []*NS{&NS{Host: "i.root-servers.net."}}; !reflect.DeepEqual(rrset, want) {
+ records := []string{}
+ for _, rr := range rrset {
+ records = append(records, fmt.Sprintf("%v", rr))
+ }
+ t.Errorf("records = [%v]; want [%v]", strings.Join(records, " "), want[0])
+ }
+}
diff --git a/libgo/go/net/dnsconfig_unix.go b/libgo/go/net/dnsconfig_unix.go
index 4b11602..37f3cce 100644
--- a/libgo/go/net/dnsconfig_unix.go
+++ b/libgo/go/net/dnsconfig_unix.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build aix || darwin || dragonfly || freebsd || hurd || linux || netbsd || openbsd || solaris
-// +build aix darwin dragonfly freebsd hurd linux netbsd openbsd solaris
// Read system DNS config from /etc/resolv.conf
diff --git a/libgo/go/net/dnsconfig_unix_test.go b/libgo/go/net/dnsconfig_unix_test.go
index 59e21d6..652a68f 100644
--- a/libgo/go/net/dnsconfig_unix_test.go
+++ b/libgo/go/net/dnsconfig_unix_test.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build aix || darwin || dragonfly || freebsd || hurd || linux || netbsd || openbsd || solaris
-// +build aix darwin dragonfly freebsd hurd linux netbsd openbsd solaris
package net
diff --git a/libgo/go/net/dnsname_test.go b/libgo/go/net/dnsname_test.go
index d851bf7..28b7c68 100644
--- a/libgo/go/net/dnsname_test.go
+++ b/libgo/go/net/dnsname_test.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build !js
-// +build !js
package net
diff --git a/libgo/go/net/error_plan9_test.go b/libgo/go/net/error_plan9_test.go
index d7c7f14..1270af1 100644
--- a/libgo/go/net/error_plan9_test.go
+++ b/libgo/go/net/error_plan9_test.go
@@ -17,3 +17,7 @@ func isPlatformError(err error) bool {
_, ok := err.(syscall.ErrorString)
return ok
}
+
+func isENOBUFS(err error) bool {
+ return false // ENOBUFS is Unix-specific
+}
diff --git a/libgo/go/net/error_posix.go b/libgo/go/net/error_posix.go
index 017f2cb..94c73cc 100644
--- a/libgo/go/net/error_posix.go
+++ b/libgo/go/net/error_posix.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build aix || darwin || dragonfly || freebsd || hurd || (js && wasm) || linux || netbsd || openbsd || solaris || windows
-// +build aix darwin dragonfly freebsd hurd js,wasm linux netbsd openbsd solaris windows
package net
diff --git a/libgo/go/net/error_posix_test.go b/libgo/go/net/error_posix_test.go
index ea52a45..081176f 100644
--- a/libgo/go/net/error_posix_test.go
+++ b/libgo/go/net/error_posix_test.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build !plan9
-// +build !plan9
package net
diff --git a/libgo/go/net/error_test.go b/libgo/go/net/error_test.go
index c304390..4a191673 100644
--- a/libgo/go/net/error_test.go
+++ b/libgo/go/net/error_test.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build !js
-// +build !js
package net
@@ -554,10 +553,7 @@ third:
}
func TestCloseError(t *testing.T) {
- ln, err := newLocalListener("tcp")
- if err != nil {
- t.Fatal(err)
- }
+ ln := newLocalListener(t, "tcp")
defer ln.Close()
c, err := Dial(ln.Addr().Network(), ln.Addr().String())
if err != nil {
@@ -665,10 +661,7 @@ func TestAcceptError(t *testing.T) {
c.Close()
}
}
- ls, err := newLocalServer("tcp")
- if err != nil {
- t.Fatal(err)
- }
+ ls := newLocalServer(t, "tcp")
if err := ls.buildup(handler); err != nil {
ls.teardown()
t.Fatal(err)
@@ -774,10 +767,7 @@ func TestFileError(t *testing.T) {
t.Error("should fail")
}
- ln, err = newLocalListener("tcp")
- if err != nil {
- t.Fatal(err)
- }
+ ln = newLocalListener(t, "tcp")
for i := 0; i < 3; i++ {
f, err := ln.(*TCPListener).File()
diff --git a/libgo/go/net/error_unix.go b/libgo/go/net/error_unix.go
index 3de4e76..775e4a0 100644
--- a/libgo/go/net/error_unix.go
+++ b/libgo/go/net/error_unix.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build aix || darwin || dragonfly || freebsd || hurd || js || linux || netbsd || openbsd || solaris
-// +build aix darwin dragonfly freebsd hurd js linux netbsd openbsd solaris
package net
diff --git a/libgo/go/net/error_unix_test.go b/libgo/go/net/error_unix_test.go
index 533a45e..291a723 100644
--- a/libgo/go/net/error_unix_test.go
+++ b/libgo/go/net/error_unix_test.go
@@ -3,11 +3,11 @@
// license that can be found in the LICENSE file.
//go:build !plan9 && !windows
-// +build !plan9,!windows
package net
import (
+ "errors"
"os"
"syscall"
)
@@ -33,3 +33,7 @@ func samePlatformError(err, want error) bool {
}
return err == want
}
+
+func isENOBUFS(err error) bool {
+ return errors.Is(err, syscall.ENOBUFS)
+}
diff --git a/libgo/go/net/error_windows_test.go b/libgo/go/net/error_windows_test.go
index 834a9de..25825f9 100644
--- a/libgo/go/net/error_windows_test.go
+++ b/libgo/go/net/error_windows_test.go
@@ -4,7 +4,10 @@
package net
-import "syscall"
+import (
+ "errors"
+ "syscall"
+)
var (
errTimedout = syscall.ETIMEDOUT
@@ -17,3 +20,10 @@ func isPlatformError(err error) bool {
_, ok := err.(syscall.Errno)
return ok
}
+
+func isENOBUFS(err error) bool {
+ // syscall.ENOBUFS is a completely made-up value on Windows: we don't expect
+ // a real system call to ever actually return it. However, since it is already
+ // defined in the syscall package we may as well check for it.
+ return errors.Is(err, syscall.ENOBUFS)
+}
diff --git a/libgo/go/net/example_test.go b/libgo/go/net/example_test.go
index 72c7183..2c045d7 100644
--- a/libgo/go/net/example_test.go
+++ b/libgo/go/net/example_test.go
@@ -124,6 +124,176 @@ func ExampleIP_DefaultMask() {
// ffffff00
}
+func ExampleIP_Equal() {
+ ipv4DNS := net.ParseIP("8.8.8.8")
+ ipv4Lo := net.ParseIP("127.0.0.1")
+ ipv6DNS := net.ParseIP("0:0:0:0:0:FFFF:0808:0808")
+
+ fmt.Println(ipv4DNS.Equal(ipv4DNS))
+ fmt.Println(ipv4DNS.Equal(ipv4Lo))
+ fmt.Println(ipv4DNS.Equal(ipv6DNS))
+
+ // Output:
+ // true
+ // false
+ // true
+}
+
+func ExampleIP_IsGlobalUnicast() {
+ ipv6Global := net.ParseIP("2000::")
+ ipv6UniqLocal := net.ParseIP("2000::")
+ ipv6Multi := net.ParseIP("FF00::")
+
+ ipv4Private := net.ParseIP("10.255.0.0")
+ ipv4Public := net.ParseIP("8.8.8.8")
+ ipv4Broadcast := net.ParseIP("255.255.255.255")
+
+ fmt.Println(ipv6Global.IsGlobalUnicast())
+ fmt.Println(ipv6UniqLocal.IsGlobalUnicast())
+ fmt.Println(ipv6Multi.IsGlobalUnicast())
+
+ fmt.Println(ipv4Private.IsGlobalUnicast())
+ fmt.Println(ipv4Public.IsGlobalUnicast())
+ fmt.Println(ipv4Broadcast.IsGlobalUnicast())
+
+ // Output:
+ // true
+ // true
+ // false
+ // true
+ // true
+ // false
+}
+
+func ExampleIP_IsInterfaceLocalMulticast() {
+ ipv6InterfaceLocalMulti := net.ParseIP("ff01::1")
+ ipv6Global := net.ParseIP("2000::")
+ ipv4 := net.ParseIP("255.0.0.0")
+
+ fmt.Println(ipv6InterfaceLocalMulti.IsInterfaceLocalMulticast())
+ fmt.Println(ipv6Global.IsInterfaceLocalMulticast())
+ fmt.Println(ipv4.IsInterfaceLocalMulticast())
+
+ // Output:
+ // true
+ // false
+ // false
+}
+
+func ExampleIP_IsLinkLocalMulticast() {
+ ipv6LinkLocalMulti := net.ParseIP("ff02::2")
+ ipv6LinkLocalUni := net.ParseIP("fe80::")
+ ipv4LinkLocalMulti := net.ParseIP("224.0.0.0")
+ ipv4LinkLocalUni := net.ParseIP("169.254.0.0")
+
+ fmt.Println(ipv6LinkLocalMulti.IsLinkLocalMulticast())
+ fmt.Println(ipv6LinkLocalUni.IsLinkLocalMulticast())
+ fmt.Println(ipv4LinkLocalMulti.IsLinkLocalMulticast())
+ fmt.Println(ipv4LinkLocalUni.IsLinkLocalMulticast())
+
+ // Output:
+ // true
+ // false
+ // true
+ // false
+}
+
+func ExampleIP_IsLinkLocalUnicast() {
+ ipv6LinkLocalUni := net.ParseIP("fe80::")
+ ipv6Global := net.ParseIP("2000::")
+ ipv4LinkLocalUni := net.ParseIP("169.254.0.0")
+ ipv4LinkLocalMulti := net.ParseIP("224.0.0.0")
+
+ fmt.Println(ipv6LinkLocalUni.IsLinkLocalUnicast())
+ fmt.Println(ipv6Global.IsLinkLocalUnicast())
+ fmt.Println(ipv4LinkLocalUni.IsLinkLocalUnicast())
+ fmt.Println(ipv4LinkLocalMulti.IsLinkLocalUnicast())
+
+ // Output:
+ // true
+ // false
+ // true
+ // false
+}
+
+func ExampleIP_IsLoopback() {
+ ipv6Lo := net.ParseIP("::1")
+ ipv6 := net.ParseIP("ff02::1")
+ ipv4Lo := net.ParseIP("127.0.0.0")
+ ipv4 := net.ParseIP("128.0.0.0")
+
+ fmt.Println(ipv6Lo.IsLoopback())
+ fmt.Println(ipv6.IsLoopback())
+ fmt.Println(ipv4Lo.IsLoopback())
+ fmt.Println(ipv4.IsLoopback())
+
+ // Output:
+ // true
+ // false
+ // true
+ // false
+}
+
+func ExampleIP_IsMulticast() {
+ ipv6Multi := net.ParseIP("FF00::")
+ ipv6LinkLocalMulti := net.ParseIP("ff02::1")
+ ipv6Lo := net.ParseIP("::1")
+ ipv4Multi := net.ParseIP("239.0.0.0")
+ ipv4LinkLocalMulti := net.ParseIP("224.0.0.0")
+ ipv4Lo := net.ParseIP("127.0.0.0")
+
+ fmt.Println(ipv6Multi.IsMulticast())
+ fmt.Println(ipv6LinkLocalMulti.IsMulticast())
+ fmt.Println(ipv6Lo.IsMulticast())
+ fmt.Println(ipv4Multi.IsMulticast())
+ fmt.Println(ipv4LinkLocalMulti.IsMulticast())
+ fmt.Println(ipv4Lo.IsMulticast())
+
+ // Output:
+ // true
+ // true
+ // false
+ // true
+ // true
+ // false
+}
+
+func ExampleIP_IsPrivate() {
+ ipv6Private := net.ParseIP("fc00::")
+ ipv6Public := net.ParseIP("fe00::")
+ ipv4Private := net.ParseIP("10.255.0.0")
+ ipv4Public := net.ParseIP("11.0.0.0")
+
+ fmt.Println(ipv6Private.IsPrivate())
+ fmt.Println(ipv6Public.IsPrivate())
+ fmt.Println(ipv4Private.IsPrivate())
+ fmt.Println(ipv4Public.IsPrivate())
+
+ // Output:
+ // true
+ // false
+ // true
+ // false
+}
+
+func ExampleIP_IsUnspecified() {
+ ipv6Unspecified := net.ParseIP("::")
+ ipv6Specified := net.ParseIP("fe00::")
+ ipv4Unspecified := net.ParseIP("0.0.0.0")
+ ipv4Specified := net.ParseIP("8.8.8.8")
+
+ fmt.Println(ipv6Unspecified.IsUnspecified())
+ fmt.Println(ipv6Specified.IsUnspecified())
+ fmt.Println(ipv4Unspecified.IsUnspecified())
+ fmt.Println(ipv4Specified.IsUnspecified())
+
+ // Output:
+ // true
+ // false
+ // true
+ // false
+}
+
func ExampleIP_Mask() {
ipv4Addr := net.ParseIP("192.0.2.1")
// This mask corresponds to a /24 subnet for IPv4.
@@ -140,6 +310,42 @@ func ExampleIP_Mask() {
// 2001:db8::
}
+func ExampleIP_String() {
+ ipv6 := net.IP{0xfc, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
+ ipv4 := net.IPv4(10, 255, 0, 0)
+
+ fmt.Println(ipv6.String())
+ fmt.Println(ipv4.String())
+
+ // Output:
+ // fc00::
+ // 10.255.0.0
+}
+
+func ExampleIP_To16() {
+ ipv6 := net.IP{0xfc, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
+ ipv4 := net.IPv4(10, 255, 0, 0)
+
+ fmt.Println(ipv6.To16())
+ fmt.Println(ipv4.To16())
+
+ // Output:
+ // fc00::
+ // 10.255.0.0
+}
+
+func ExampleIP_to4() {
+ ipv6 := net.IP{0xfc, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
+ ipv4 := net.IPv4(10, 255, 0, 0)
+
+ fmt.Println(ipv6.To4())
+ fmt.Println(ipv4.To4())
+
+ // Output:
+ // <nil>
+ // 10.255.0.0
+}
+
func ExampleCIDRMask() {
// This mask corresponds to a /31 subnet for IPv4.
fmt.Println(net.CIDRMask(31, 32))
diff --git a/libgo/go/net/external_test.go b/libgo/go/net/external_test.go
index b8753cc..3a97011 100644
--- a/libgo/go/net/external_test.go
+++ b/libgo/go/net/external_test.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build !js
-// +build !js
package net
diff --git a/libgo/go/net/fcntl_libc_test.go b/libgo/go/net/fcntl_libc_test.go
index 02511c5..f59a1aa 100644
--- a/libgo/go/net/fcntl_libc_test.go
+++ b/libgo/go/net/fcntl_libc_test.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build aix || darwin || solaris
-// +build aix darwin solaris
package net
diff --git a/libgo/go/net/fcntl_syscall_test.go b/libgo/go/net/fcntl_syscall_test.go
index 59ba1a1..58cacc4 100644
--- a/libgo/go/net/fcntl_syscall_test.go
+++ b/libgo/go/net/fcntl_syscall_test.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build dragonfly || freebsd || linux || netbsd || openbsd
-// +build dragonfly freebsd linux netbsd openbsd
package net
diff --git a/libgo/go/net/fd_posix.go b/libgo/go/net/fd_posix.go
index a0f1f5a..466ccce 100644
--- a/libgo/go/net/fd_posix.go
+++ b/libgo/go/net/fd_posix.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build aix || darwin || dragonfly || freebsd || hurd || linux || netbsd || openbsd || solaris || windows
-// +build aix darwin dragonfly freebsd hurd linux netbsd openbsd solaris windows
package net
@@ -63,6 +62,17 @@ func (fd *netFD) readFrom(p []byte) (n int, sa syscall.Sockaddr, err error) {
runtime.KeepAlive(fd)
return n, sa, wrapSyscallError(readFromSyscallName, err)
}
+func (fd *netFD) readFromInet4(p []byte, from *syscall.SockaddrInet4) (n int, err error) {
+ n, err = fd.pfd.ReadFromInet4(p, from)
+ runtime.KeepAlive(fd)
+ return n, wrapSyscallError(readFromSyscallName, err)
+}
+
+func (fd *netFD) readFromInet6(p []byte, from *syscall.SockaddrInet6) (n int, err error) {
+ n, err = fd.pfd.ReadFromInet6(p, from)
+ runtime.KeepAlive(fd)
+ return n, wrapSyscallError(readFromSyscallName, err)
+}
func (fd *netFD) readMsg(p []byte, oob []byte, flags int) (n, oobn, retflags int, sa syscall.Sockaddr, err error) {
n, oobn, retflags, sa, err = fd.pfd.ReadMsg(p, oob, flags)
@@ -70,6 +80,18 @@ func (fd *netFD) readMsg(p []byte, oob []byte, flags int) (n, oobn, retflags int
return n, oobn, retflags, sa, wrapSyscallError(readMsgSyscallName, err)
}
+func (fd *netFD) readMsgInet4(p []byte, oob []byte, flags int, sa *syscall.SockaddrInet4) (n, oobn, retflags int, err error) {
+ n, oobn, retflags, err = fd.pfd.ReadMsgInet4(p, oob, flags, sa)
+ runtime.KeepAlive(fd)
+ return n, oobn, retflags, wrapSyscallError(readMsgSyscallName, err)
+}
+
+func (fd *netFD) readMsgInet6(p []byte, oob []byte, flags int, sa *syscall.SockaddrInet6) (n, oobn, retflags int, err error) {
+ n, oobn, retflags, err = fd.pfd.ReadMsgInet6(p, oob, flags, sa)
+ runtime.KeepAlive(fd)
+ return n, oobn, retflags, wrapSyscallError(readMsgSyscallName, err)
+}
+
func (fd *netFD) Write(p []byte) (nn int, err error) {
nn, err = fd.pfd.Write(p)
runtime.KeepAlive(fd)
@@ -82,12 +104,36 @@ func (fd *netFD) writeTo(p []byte, sa syscall.Sockaddr) (n int, err error) {
return n, wrapSyscallError(writeToSyscallName, err)
}
+func (fd *netFD) writeToInet4(p []byte, sa *syscall.SockaddrInet4) (n int, err error) {
+ n, err = fd.pfd.WriteToInet4(p, sa)
+ runtime.KeepAlive(fd)
+ return n, wrapSyscallError(writeToSyscallName, err)
+}
+
+func (fd *netFD) writeToInet6(p []byte, sa *syscall.SockaddrInet6) (n int, err error) {
+ n, err = fd.pfd.WriteToInet6(p, sa)
+ runtime.KeepAlive(fd)
+ return n, wrapSyscallError(writeToSyscallName, err)
+}
+
func (fd *netFD) writeMsg(p []byte, oob []byte, sa syscall.Sockaddr) (n int, oobn int, err error) {
n, oobn, err = fd.pfd.WriteMsg(p, oob, sa)
runtime.KeepAlive(fd)
return n, oobn, wrapSyscallError(writeMsgSyscallName, err)
}
+func (fd *netFD) writeMsgInet4(p []byte, oob []byte, sa *syscall.SockaddrInet4) (n int, oobn int, err error) {
+ n, oobn, err = fd.pfd.WriteMsgInet4(p, oob, sa)
+ runtime.KeepAlive(fd)
+ return n, oobn, wrapSyscallError(writeMsgSyscallName, err)
+}
+
+func (fd *netFD) writeMsgInet6(p []byte, oob []byte, sa *syscall.SockaddrInet6) (n int, oobn int, err error) {
+ n, oobn, err = fd.pfd.WriteMsgInet6(p, oob, sa)
+ runtime.KeepAlive(fd)
+ return n, oobn, wrapSyscallError(writeMsgSyscallName, err)
+}
+
func (fd *netFD) SetDeadline(t time.Time) error {
return fd.pfd.SetDeadline(t)
}
diff --git a/libgo/go/net/fd_unix.go b/libgo/go/net/fd_unix.go
index e2db165..394e1c7 100644
--- a/libgo/go/net/fd_unix.go
+++ b/libgo/go/net/fd_unix.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build aix || darwin || dragonfly || freebsd || hurd || linux || netbsd || openbsd || solaris
-// +build aix darwin dragonfly freebsd hurd linux netbsd openbsd solaris
package net
@@ -92,12 +91,12 @@ func (fd *netFD) connect(ctx context.Context, la, ra syscall.Sockaddr) (rsa sysc
}
// Start the "interrupter" goroutine, if this context might be canceled.
- // (The background context cannot)
//
// The interrupter goroutine waits for the context to be done and
// interrupts the dial (by altering the fd's write deadline, which
// wakes up waitWrite).
- if ctx != context.Background() {
+ ctxDone := ctx.Done()
+ if ctxDone != nil {
// Wait for the interrupter goroutine to exit before returning
// from connect.
done := make(chan struct{})
@@ -117,7 +116,7 @@ func (fd *netFD) connect(ctx context.Context, la, ra syscall.Sockaddr) (rsa sysc
}()
go func() {
select {
- case <-ctx.Done():
+ case <-ctxDone:
// Force the runtime's poller to immediately give up
// waiting for writability, unblocking waitWrite
// below.
@@ -141,7 +140,7 @@ func (fd *netFD) connect(ctx context.Context, la, ra syscall.Sockaddr) (rsa sysc
// details.
if err := fd.pfd.WaitWrite(); err != nil {
select {
- case <-ctx.Done():
+ case <-ctxDone:
return nil, mapErr(ctx.Err())
default:
}
diff --git a/libgo/go/net/file_stub.go b/libgo/go/net/file_stub.go
index 9f988fe..91df926 100644
--- a/libgo/go/net/file_stub.go
+++ b/libgo/go/net/file_stub.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build js && wasm
-// +build js,wasm
package net
diff --git a/libgo/go/net/file_test.go b/libgo/go/net/file_test.go
index a70ef1b..ea2a218 100644
--- a/libgo/go/net/file_test.go
+++ b/libgo/go/net/file_test.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build !js
-// +build !js
package net
@@ -45,10 +44,7 @@ func TestFileConn(t *testing.T) {
var network, address string
switch tt.network {
case "udp":
- c, err := newLocalPacketListener(tt.network)
- if err != nil {
- t.Fatal(err)
- }
+ c := newLocalPacketListener(t, tt.network)
defer c.Close()
network = c.LocalAddr().Network()
address = c.LocalAddr().String()
@@ -62,10 +58,7 @@ func TestFileConn(t *testing.T) {
var b [1]byte
c.Read(b[:])
}
- ls, err := newLocalServer(tt.network)
- if err != nil {
- t.Fatal(err)
- }
+ ls := newLocalServer(t, tt.network)
defer ls.teardown()
if err := ls.buildup(handler); err != nil {
t.Fatal(err)
@@ -149,17 +142,17 @@ func TestFileListener(t *testing.T) {
continue
}
- ln1, err := newLocalListener(tt.network)
- if err != nil {
- t.Fatal(err)
- }
+ ln1 := newLocalListener(t, tt.network)
switch tt.network {
case "unix", "unixpacket":
defer os.Remove(ln1.Addr().String())
}
addr := ln1.Addr()
- var f *os.File
+ var (
+ f *os.File
+ err error
+ )
switch ln1 := ln1.(type) {
case *TCPListener:
f, err = ln1.File()
@@ -241,17 +234,17 @@ func TestFilePacketConn(t *testing.T) {
continue
}
- c1, err := newLocalPacketListener(tt.network)
- if err != nil {
- t.Fatal(err)
- }
+ c1 := newLocalPacketListener(t, tt.network)
switch tt.network {
case "unixgram":
defer os.Remove(c1.LocalAddr().String())
}
addr := c1.LocalAddr()
- var f *os.File
+ var (
+ f *os.File
+ err error
+ )
switch c1 := c1.(type) {
case *UDPConn:
f, err = c1.File()
@@ -315,10 +308,7 @@ func TestFileCloseRace(t *testing.T) {
c.Read(b[:])
}
- ls, err := newLocalServer("tcp")
- if err != nil {
- t.Fatal(err)
- }
+ ls := newLocalServer(t, "tcp")
defer ls.teardown()
if err := ls.buildup(handler); err != nil {
t.Fatal(err)
diff --git a/libgo/go/net/file_unix.go b/libgo/go/net/file_unix.go
index d36a881..afb1d98 100644
--- a/libgo/go/net/file_unix.go
+++ b/libgo/go/net/file_unix.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build aix || darwin || dragonfly || freebsd || hurd || linux || netbsd || openbsd || solaris
-// +build aix darwin dragonfly freebsd hurd linux netbsd openbsd solaris
package net
diff --git a/libgo/go/net/hook_unix.go b/libgo/go/net/hook_unix.go
index 618c6c2..5629476 100644
--- a/libgo/go/net/hook_unix.go
+++ b/libgo/go/net/hook_unix.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build aix || darwin || dragonfly || freebsd || hurd || (js && wasm) || linux || netbsd || openbsd || solaris
-// +build aix darwin dragonfly freebsd hurd js,wasm linux netbsd openbsd solaris
package net
diff --git a/libgo/go/net/hosts.go b/libgo/go/net/hosts.go
index 5c560f3..e604031 100644
--- a/libgo/go/net/hosts.go
+++ b/libgo/go/net/hosts.go
@@ -82,10 +82,10 @@ func readHosts() {
continue
}
for i := 1; i < len(f); i++ {
- name := absDomainName([]byte(f[i]))
+ name := absDomainName(f[i])
h := []byte(f[i])
lowerASCIIBytes(h)
- key := absDomainName(h)
+ key := absDomainName(string(h))
hs[key] = append(hs[key], addr)
is[addr] = append(is[addr], name)
}
@@ -106,11 +106,12 @@ func lookupStaticHost(host string) []string {
defer hosts.Unlock()
readHosts()
if len(hosts.byName) != 0 {
- // TODO(jbd,bradfitz): avoid this alloc if host is already all lowercase?
- // or linear scan the byName map if it's small enough?
- lowerHost := []byte(host)
- lowerASCIIBytes(lowerHost)
- if ips, ok := hosts.byName[absDomainName(lowerHost)]; ok {
+ if hasUpperCase(host) {
+ lowerHost := []byte(host)
+ lowerASCIIBytes(lowerHost)
+ host = string(lowerHost)
+ }
+ if ips, ok := hosts.byName[absDomainName(host)]; ok {
ipsCp := make([]string, len(ips))
copy(ipsCp, ips)
return ipsCp
diff --git a/libgo/go/net/hosts_test.go b/libgo/go/net/hosts_test.go
index 19c4399..7291914 100644
--- a/libgo/go/net/hosts_test.go
+++ b/libgo/go/net/hosts_test.go
@@ -70,7 +70,7 @@ func TestLookupStaticHost(t *testing.T) {
}
func testStaticHost(t *testing.T, hostsPath string, ent staticHostEntry) {
- ins := []string{ent.in, absDomainName([]byte(ent.in)), strings.ToLower(ent.in), strings.ToUpper(ent.in)}
+ ins := []string{ent.in, absDomainName(ent.in), strings.ToLower(ent.in), strings.ToUpper(ent.in)}
for _, in := range ins {
addrs := lookupStaticHost(in)
if !reflect.DeepEqual(addrs, ent.out) {
@@ -141,7 +141,7 @@ func TestLookupStaticAddr(t *testing.T) {
func testStaticAddr(t *testing.T, hostsPath string, ent staticHostEntry) {
hosts := lookupStaticAddr(ent.in)
for i := range ent.out {
- ent.out[i] = absDomainName([]byte(ent.out[i]))
+ ent.out[i] = absDomainName(ent.out[i])
}
if !reflect.DeepEqual(hosts, ent.out) {
t.Errorf("%s, lookupStaticAddr(%s) = %v; want %v", hostsPath, ent.in, hosts, ent.out)
diff --git a/libgo/go/net/http/cgi/child.go b/libgo/go/net/http/cgi/child.go
index 0114da3..bdb35a6 100644
--- a/libgo/go/net/http/cgi/child.go
+++ b/libgo/go/net/http/cgi/child.go
@@ -39,8 +39,8 @@ func Request() (*http.Request, error) {
func envMap(env []string) map[string]string {
m := make(map[string]string)
for _, kv := range env {
- if idx := strings.Index(kv, "="); idx != -1 {
- m[kv[:idx]] = kv[idx+1:]
+ if k, v, ok := strings.Cut(kv, "="); ok {
+ m[k] = v
}
}
return m
diff --git a/libgo/go/net/http/cgi/host.go b/libgo/go/net/http/cgi/host.go
index eff67ca..95b2e13 100644
--- a/libgo/go/net/http/cgi/host.go
+++ b/libgo/go/net/http/cgi/host.go
@@ -273,12 +273,11 @@ func (h *Handler) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
break
}
headerLines++
- parts := strings.SplitN(string(line), ":", 2)
- if len(parts) < 2 {
+ header, val, ok := strings.Cut(string(line), ":")
+ if !ok {
h.printf("cgi: bogus header line: %s", string(line))
continue
}
- header, val := parts[0], parts[1]
if !httpguts.ValidHeaderFieldName(header) {
h.printf("cgi: invalid header name: %q", header)
continue
@@ -351,7 +350,7 @@ func (h *Handler) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
}
}
-func (h *Handler) printf(format string, v ...interface{}) {
+func (h *Handler) printf(format string, v ...any) {
if h.Logger != nil {
h.Logger.Printf(format, v...)
} else {
diff --git a/libgo/go/net/http/cgi/host_test.go b/libgo/go/net/http/cgi/host_test.go
index 9f1716b..1b72f7e 100644
--- a/libgo/go/net/http/cgi/host_test.go
+++ b/libgo/go/net/http/cgi/host_test.go
@@ -62,12 +62,12 @@ readlines:
}
linesRead++
trimmedLine := strings.TrimRight(line, "\r\n")
- split := strings.SplitN(trimmedLine, "=", 2)
- if len(split) != 2 {
- t.Fatalf("Unexpected %d parts from invalid line number %v: %q; existing map=%v",
- len(split), linesRead, line, m)
+ k, v, ok := strings.Cut(trimmedLine, "=")
+ if !ok {
+ t.Fatalf("Unexpected response from invalid line number %v: %q; existing map=%v",
+ linesRead, line, m)
}
- m[split[0]] = split[1]
+ m[k] = v
}
for key, expected := range expectedMap {
diff --git a/libgo/go/net/http/cgi/posix_test.go b/libgo/go/net/http/cgi/posix_test.go
index bc58ea9..49b9470 100644
--- a/libgo/go/net/http/cgi/posix_test.go
+++ b/libgo/go/net/http/cgi/posix_test.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build !plan9
-// +build !plan9
package cgi
diff --git a/libgo/go/net/http/client.go b/libgo/go/net/http/client.go
index 4d380c6..22db96b 100644
--- a/libgo/go/net/http/client.go
+++ b/libgo/go/net/http/client.go
@@ -965,7 +965,6 @@ func (b *cancelTimerBody) Read(p []byte) (n int, err error) {
if err == nil {
return n, nil
}
- b.stop()
if err == io.EOF {
return n, err
}
diff --git a/libgo/go/net/http/client_test.go b/libgo/go/net/http/client_test.go
index 01d605c..e91d526 100644
--- a/libgo/go/net/http/client_test.go
+++ b/libgo/go/net/http/client_test.go
@@ -13,6 +13,7 @@ import (
"encoding/base64"
"errors"
"fmt"
+ "internal/testenv"
"io"
"log"
"net"
@@ -21,6 +22,7 @@ import (
"net/http/httptest"
"net/url"
"reflect"
+ "runtime"
"strconv"
"strings"
"sync"
@@ -431,11 +433,10 @@ func testRedirectsByMethod(t *testing.T, method string, table []redirectTest, wa
if v := urlQuery.Get("code"); v != "" {
location := ts.URL
if final := urlQuery.Get("next"); final != "" {
- splits := strings.Split(final, ",")
- first, rest := splits[0], splits[1:]
+ first, rest, _ := strings.Cut(final, ",")
location = fmt.Sprintf("%s?code=%s", location, first)
- if len(rest) > 0 {
- location = fmt.Sprintf("%s&next=%s", location, strings.Join(rest, ","))
+ if rest != "" {
+ location = fmt.Sprintf("%s&next=%s", location, rest)
}
}
code, _ := strconv.Atoi(v)
@@ -746,7 +747,7 @@ func (j *RecordingJar) Cookies(u *url.URL) []*Cookie {
return nil
}
-func (j *RecordingJar) logf(format string, args ...interface{}) {
+func (j *RecordingJar) logf(format string, args ...any) {
j.mu.Lock()
defer j.mu.Unlock()
fmt.Fprintf(&j.log, format, args...)
@@ -1206,64 +1207,80 @@ func TestClientTimeout_h2(t *testing.T) { testClientTimeout(t, h2Mode) }
func testClientTimeout(t *testing.T, h2 bool) {
setParallel(t)
defer afterTest(t)
- testDone := make(chan struct{}) // closed in defer below
- sawRoot := make(chan bool, 1)
- sawSlow := make(chan bool, 1)
+ var (
+ mu sync.Mutex
+ nonce string // a unique per-request string
+ sawSlowNonce bool // true if the handler saw /slow?nonce=<nonce>
+ )
cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
+ _ = r.ParseForm()
if r.URL.Path == "/" {
- sawRoot <- true
- Redirect(w, r, "/slow", StatusFound)
+ Redirect(w, r, "/slow?nonce="+r.Form.Get("nonce"), StatusFound)
return
}
if r.URL.Path == "/slow" {
- sawSlow <- true
+ mu.Lock()
+ if r.Form.Get("nonce") == nonce {
+ sawSlowNonce = true
+ } else {
+ t.Logf("mismatched nonce: received %s, want %s", r.Form.Get("nonce"), nonce)
+ }
+ mu.Unlock()
+
w.Write([]byte("Hello"))
w.(Flusher).Flush()
- <-testDone
+ <-r.Context().Done()
return
}
}))
defer cst.close()
- defer close(testDone) // before cst.close, to unblock /slow handler
- // 200ms should be long enough to get a normal request (the /
- // handler), but not so long that it makes the test slow.
- const timeout = 200 * time.Millisecond
- cst.c.Timeout = timeout
-
- res, err := cst.c.Get(cst.ts.URL)
- if err != nil {
- if strings.Contains(err.Error(), "Client.Timeout") {
- t.Skipf("host too slow to get fast resource in %v", timeout)
+ // Try to trigger a timeout after reading part of the response body.
+ // The initial timeout is emprically usually long enough on a decently fast
+ // machine, but if we undershoot we'll retry with exponentially longer
+ // timeouts until the test either passes or times out completely.
+ // This keeps the test reasonably fast in the typical case but allows it to
+ // also eventually succeed on arbitrarily slow machines.
+ timeout := 10 * time.Millisecond
+ nextNonce := 0
+ for ; ; timeout *= 2 {
+ if timeout <= 0 {
+ // The only way we can feasibly hit this while the test is running is if
+ // the request fails without actually waiting for the timeout to occur.
+ t.Fatalf("timeout overflow")
+ }
+ if deadline, ok := t.Deadline(); ok && !time.Now().Add(timeout).Before(deadline) {
+ t.Fatalf("failed to produce expected timeout before test deadline")
+ }
+ t.Logf("attempting test with timeout %v", timeout)
+ cst.c.Timeout = timeout
+
+ mu.Lock()
+ nonce = fmt.Sprint(nextNonce)
+ nextNonce++
+ sawSlowNonce = false
+ mu.Unlock()
+ res, err := cst.c.Get(cst.ts.URL + "/?nonce=" + nonce)
+ if err != nil {
+ if strings.Contains(err.Error(), "Client.Timeout") {
+ // Timed out before handler could respond.
+ t.Logf("timeout before response received")
+ continue
+ }
+ t.Fatal(err)
}
- t.Fatal(err)
- }
-
- select {
- case <-sawRoot:
- // good.
- default:
- t.Fatal("handler never got / request")
- }
- select {
- case <-sawSlow:
- // good.
- default:
- t.Fatal("handler never got /slow request")
- }
+ mu.Lock()
+ ok := sawSlowNonce
+ mu.Unlock()
+ if !ok {
+ t.Fatal("handler never got /slow request, but client returned response")
+ }
- errc := make(chan error, 1)
- go func() {
- _, err := io.ReadAll(res.Body)
- errc <- err
+ _, err = io.ReadAll(res.Body)
res.Body.Close()
- }()
- const failTime = 5 * time.Second
- select {
- case err := <-errc:
if err == nil {
t.Fatal("expected error from ReadAll")
}
@@ -1274,10 +1291,13 @@ func testClientTimeout(t *testing.T, h2 bool) {
t.Errorf("net.Error.Timeout = false; want true")
}
if got := ne.Error(); !strings.Contains(got, "(Client.Timeout") {
+ if runtime.GOOS == "windows" && strings.HasPrefix(runtime.GOARCH, "arm") {
+ testenv.SkipFlaky(t, 43120)
+ }
t.Errorf("error string = %q; missing timeout substring", got)
}
- case <-time.After(failTime):
- t.Errorf("timeout after %v waiting for timeout of %v", failTime, timeout)
+
+ break
}
}
@@ -1319,6 +1339,9 @@ func testClientTimeout_Headers(t *testing.T, h2 bool) {
t.Error("net.Error.Timeout = false; want true")
}
if got := ne.Error(); !strings.Contains(got, "Client.Timeout exceeded") {
+ if runtime.GOOS == "windows" && strings.HasPrefix(runtime.GOARCH, "arm") {
+ testenv.SkipFlaky(t, 43120)
+ }
t.Errorf("error string = %q; missing timeout substring", got)
}
}
@@ -1353,6 +1376,33 @@ func TestClientTimeoutCancel(t *testing.T) {
}
}
+func TestClientTimeoutDoesNotExpire_h1(t *testing.T) { testClientTimeoutDoesNotExpire(t, h1Mode) }
+func TestClientTimeoutDoesNotExpire_h2(t *testing.T) { testClientTimeoutDoesNotExpire(t, h2Mode) }
+
+// Issue 49366: if Client.Timeout is set but not hit, no error should be returned.
+func testClientTimeoutDoesNotExpire(t *testing.T, h2 bool) {
+ setParallel(t)
+ defer afterTest(t)
+
+ cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
+ w.Write([]byte("body"))
+ }))
+ defer cst.close()
+
+ cst.c.Timeout = 1 * time.Hour
+ req, _ := NewRequest("GET", cst.ts.URL, nil)
+ res, err := cst.c.Do(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if _, err = io.Copy(io.Discard, res.Body); err != nil {
+ t.Fatalf("io.Copy(io.Discard, res.Body) = %v, want nil", err)
+ }
+ if err = res.Body.Close(); err != nil {
+ t.Fatalf("res.Body.Close() = %v, want nil", err)
+ }
+}
+
func TestClientRedirectEatsBody_h1(t *testing.T) { testClientRedirectEatsBody(t, h1Mode) }
func TestClientRedirectEatsBody_h2(t *testing.T) { testClientRedirectEatsBody(t, h2Mode) }
func testClientRedirectEatsBody(t *testing.T, h2 bool) {
@@ -2082,3 +2132,47 @@ func (b *issue40382Body) Close() error {
}
return nil
}
+
+func TestProbeZeroLengthBody(t *testing.T) {
+ setParallel(t)
+ defer afterTest(t)
+ reqc := make(chan struct{})
+ cst := newClientServerTest(t, false, HandlerFunc(func(w ResponseWriter, r *Request) {
+ close(reqc)
+ if _, err := io.Copy(w, r.Body); err != nil {
+ t.Errorf("error copying request body: %v", err)
+ }
+ }))
+ defer cst.close()
+
+ bodyr, bodyw := io.Pipe()
+ var gotBody string
+ var wg sync.WaitGroup
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ req, _ := NewRequest("GET", cst.ts.URL, bodyr)
+ res, err := cst.c.Do(req)
+ b, err := io.ReadAll(res.Body)
+ if err != nil {
+ t.Error(err)
+ }
+ gotBody = string(b)
+ }()
+
+ select {
+ case <-reqc:
+ // Request should be sent after trying to probe the request body for 200ms.
+ case <-time.After(60 * time.Second):
+ t.Errorf("request not sent after 60s")
+ }
+
+ // Write the request body and wait for the request to complete.
+ const content = "body"
+ bodyw.Write([]byte(content))
+ bodyw.Close()
+ wg.Wait()
+ if gotBody != content {
+ t.Fatalf("server got body %q, want %q", gotBody, content)
+ }
+}
diff --git a/libgo/go/net/http/clientserver_test.go b/libgo/go/net/http/clientserver_test.go
index 42207ac..44d70f0 100644
--- a/libgo/go/net/http/clientserver_test.go
+++ b/libgo/go/net/http/clientserver_test.go
@@ -81,7 +81,7 @@ func optWithServerLog(lg *log.Logger) func(*httptest.Server) {
}
}
-func newClientServerTest(t *testing.T, h2 bool, h Handler, opts ...interface{}) *clientServerTest {
+func newClientServerTest(t *testing.T, h2 bool, h Handler, opts ...any) *clientServerTest {
if h2 {
CondSkipHTTP2(t)
}
@@ -189,7 +189,7 @@ type h12Compare struct {
ReqFunc reqFunc // optional
CheckResponse func(proto string, res *Response) // optional
EarlyCheckResponse func(proto string, res *Response) // optional; pre-normalize
- Opts []interface{}
+ Opts []any
}
func (tt h12Compare) reqFunc() reqFunc {
@@ -441,7 +441,7 @@ func TestH12_AutoGzip(t *testing.T) {
func TestH12_AutoGzip_Disabled(t *testing.T) {
h12Compare{
- Opts: []interface{}{
+ Opts: []any{
func(tr *Transport) { tr.DisableCompression = true },
},
Handler: func(w ResponseWriter, r *Request) {
@@ -1172,7 +1172,7 @@ func TestInterruptWithPanic_ErrAbortHandler_h1(t *testing.T) {
func TestInterruptWithPanic_ErrAbortHandler_h2(t *testing.T) {
testInterruptWithPanic(t, h2Mode, ErrAbortHandler)
}
-func testInterruptWithPanic(t *testing.T, h2 bool, panicValue interface{}) {
+func testInterruptWithPanic(t *testing.T, h2 bool, panicValue any) {
setParallel(t)
const msg = "hello"
defer afterTest(t)
@@ -1522,7 +1522,7 @@ func TestBidiStreamReverseProxy(t *testing.T) {
}))
defer proxy.close()
- bodyRes := make(chan interface{}, 1) // error or hash.Hash
+ bodyRes := make(chan any, 1) // error or hash.Hash
pr, pw := io.Pipe()
req, _ := NewRequest("PUT", proxy.ts.URL, pr)
const size = 4 << 20
@@ -1586,3 +1586,37 @@ func TestH12_WebSocketUpgrade(t *testing.T) {
},
}.run(t)
}
+
+func TestIdentityTransferEncoding_h1(t *testing.T) { testIdentityTransferEncoding(t, h1Mode) }
+func TestIdentityTransferEncoding_h2(t *testing.T) { testIdentityTransferEncoding(t, h2Mode) }
+
+func testIdentityTransferEncoding(t *testing.T, h2 bool) {
+ setParallel(t)
+ defer afterTest(t)
+
+ const body = "body"
+ cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
+ gotBody, _ := io.ReadAll(r.Body)
+ if got, want := string(gotBody), body; got != want {
+ t.Errorf("got request body = %q; want %q", got, want)
+ }
+ w.Header().Set("Transfer-Encoding", "identity")
+ w.WriteHeader(StatusOK)
+ w.(Flusher).Flush()
+ io.WriteString(w, body)
+ }))
+ defer cst.close()
+ req, _ := NewRequest("GET", cst.ts.URL, strings.NewReader(body))
+ res, err := cst.c.Do(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer res.Body.Close()
+ gotBody, err := io.ReadAll(res.Body)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if got, want := string(gotBody), body; got != want {
+ t.Errorf("got response body = %q; want %q", got, want)
+ }
+}
diff --git a/libgo/go/net/http/cookie.go b/libgo/go/net/http/cookie.go
index ca2c1c2..cb37f23 100644
--- a/libgo/go/net/http/cookie.go
+++ b/libgo/go/net/http/cookie.go
@@ -5,6 +5,8 @@
package http
import (
+ "errors"
+ "fmt"
"log"
"net"
"net/http/internal/ascii"
@@ -67,15 +69,14 @@ func readSetCookies(h Header) []*Cookie {
continue
}
parts[0] = textproto.TrimString(parts[0])
- j := strings.Index(parts[0], "=")
- if j < 0 {
+ name, value, ok := strings.Cut(parts[0], "=")
+ if !ok {
continue
}
- name, value := parts[0][:j], parts[0][j+1:]
if !isCookieNameValid(name) {
continue
}
- value, ok := parseCookieValue(value, true)
+ value, ok = parseCookieValue(value, true)
if !ok {
continue
}
@@ -90,10 +91,7 @@ func readSetCookies(h Header) []*Cookie {
continue
}
- attr, val := parts[i], ""
- if j := strings.Index(attr, "="); j >= 0 {
- attr, val = attr[:j], attr[j+1:]
- }
+ attr, val, _ := strings.Cut(parts[i], "=")
lowerAttr, isASCII := ascii.ToLower(attr)
if !isASCII {
continue
@@ -240,6 +238,37 @@ func (c *Cookie) String() string {
return b.String()
}
+// Valid reports whether the cookie is valid.
+func (c *Cookie) Valid() error {
+ if c == nil {
+ return errors.New("http: nil Cookie")
+ }
+ if !isCookieNameValid(c.Name) {
+ return errors.New("http: invalid Cookie.Name")
+ }
+ if !validCookieExpires(c.Expires) {
+ return errors.New("http: invalid Cookie.Expires")
+ }
+ for i := 0; i < len(c.Value); i++ {
+ if !validCookieValueByte(c.Value[i]) {
+ return fmt.Errorf("http: invalid byte %q in Cookie.Value", c.Value[i])
+ }
+ }
+ if len(c.Path) > 0 {
+ for i := 0; i < len(c.Path); i++ {
+ if !validCookiePathByte(c.Path[i]) {
+ return fmt.Errorf("http: invalid byte %q in Cookie.Path", c.Path[i])
+ }
+ }
+ }
+ if len(c.Domain) > 0 {
+ if !validCookieDomain(c.Domain) {
+ return errors.New("http: invalid Cookie.Domain")
+ }
+ }
+ return nil
+}
+
// readCookies parses all "Cookie" values from the header h and
// returns the successfully parsed Cookies.
//
@@ -256,19 +285,12 @@ func readCookies(h Header, filter string) []*Cookie {
var part string
for len(line) > 0 { // continue since we have rest
- if splitIndex := strings.Index(line, ";"); splitIndex > 0 {
- part, line = line[:splitIndex], line[splitIndex+1:]
- } else {
- part, line = line, ""
- }
+ part, line, _ = strings.Cut(line, ";")
part = textproto.TrimString(part)
- if len(part) == 0 {
+ if part == "" {
continue
}
- name, val := part, ""
- if j := strings.Index(part, "="); j >= 0 {
- name, val = name[:j], name[j+1:]
- }
+ name, val, _ := strings.Cut(part, "=")
if !isCookieNameValid(name) {
continue
}
@@ -379,7 +401,7 @@ func sanitizeCookieValue(v string) string {
if len(v) == 0 {
return v
}
- if strings.IndexByte(v, ' ') >= 0 || strings.IndexByte(v, ',') >= 0 {
+ if strings.ContainsAny(v, " ,") {
return `"` + v + `"`
}
return v
diff --git a/libgo/go/net/http/cookie_test.go b/libgo/go/net/http/cookie_test.go
index 959713a..ccc5f98 100644
--- a/libgo/go/net/http/cookie_test.go
+++ b/libgo/go/net/http/cookie_test.go
@@ -360,7 +360,7 @@ var readSetCookiesTests = []struct {
// Header{"Set-Cookie": {"ASP.NET_SessionId=foo; path=/; HttpOnly, .ASPXAUTH=7E3AA; expires=Wed, 07-Mar-2012 14:25:06 GMT; path=/; HttpOnly"}},
}
-func toJSON(v interface{}) string {
+func toJSON(v any) string {
b, err := json.Marshal(v)
if err != nil {
return fmt.Sprintf("%#v", v)
@@ -529,6 +529,31 @@ func TestCookieSanitizePath(t *testing.T) {
}
}
+func TestCookieValid(t *testing.T) {
+ tests := []struct {
+ cookie *Cookie
+ valid bool
+ }{
+ {nil, false},
+ {&Cookie{Name: ""}, false},
+ {&Cookie{Name: "invalid-expires"}, false},
+ {&Cookie{Name: "invalid-value", Value: "foo\"bar"}, false},
+ {&Cookie{Name: "invalid-path", Path: "/foo;bar/"}, false},
+ {&Cookie{Name: "invalid-domain", Domain: "example.com:80"}, false},
+ {&Cookie{Name: "valid", Value: "foo", Path: "/bar", Domain: "example.com", Expires: time.Unix(0, 0)}, true},
+ }
+
+ for _, tt := range tests {
+ err := tt.cookie.Valid()
+ if err != nil && tt.valid {
+ t.Errorf("%#v.Valid() returned error %v; want nil", tt.cookie, err)
+ }
+ if err == nil && !tt.valid {
+ t.Errorf("%#v.Valid() returned nil; want error", tt.cookie)
+ }
+ }
+}
+
func BenchmarkCookieString(b *testing.B) {
const wantCookieString = `cookie-9=i3e01nf61b6t23bvfmplnanol3; Path=/restricted/; Domain=example.com; Expires=Tue, 10 Nov 2009 23:00:00 GMT; Max-Age=3600`
c := &Cookie{
diff --git a/libgo/go/net/http/export_test.go b/libgo/go/net/http/export_test.go
index 096a6d3..a849327 100644
--- a/libgo/go/net/http/export_test.go
+++ b/libgo/go/net/http/export_test.go
@@ -88,12 +88,7 @@ func SetPendingDialHooks(before, after func()) {
func SetTestHookServerServe(fn func(*Server, net.Listener)) { testHookServerServe = fn }
-func NewTestTimeoutHandler(handler Handler, ch <-chan time.Time) Handler {
- ctx, cancel := context.WithCancel(context.Background())
- go func() {
- <-ch
- cancel()
- }()
+func NewTestTimeoutHandler(handler Handler, ctx context.Context) Handler {
return &timeoutHandler{
handler: handler,
testContext: ctx,
diff --git a/libgo/go/net/http/fs.go b/libgo/go/net/http/fs.go
index 57e731e..6caee9e 100644
--- a/libgo/go/net/http/fs.go
+++ b/libgo/go/net/http/fs.go
@@ -42,20 +42,20 @@ import (
// An empty Dir is treated as ".".
type Dir string
-// mapDirOpenError maps the provided non-nil error from opening name
+// mapOpenError maps the provided non-nil error from opening name
// to a possibly better non-nil error. In particular, it turns OS-specific errors
-// about opening files in non-directories into fs.ErrNotExist. See Issue 18984.
-func mapDirOpenError(originalErr error, name string) error {
+// about opening files in non-directories into fs.ErrNotExist. See Issues 18984 and 49552.
+func mapOpenError(originalErr error, name string, sep rune, stat func(string) (fs.FileInfo, error)) error {
if errors.Is(originalErr, fs.ErrNotExist) || errors.Is(originalErr, fs.ErrPermission) {
return originalErr
}
- parts := strings.Split(name, string(filepath.Separator))
+ parts := strings.Split(name, string(sep))
for i := range parts {
if parts[i] == "" {
continue
}
- fi, err := os.Stat(strings.Join(parts[:i+1], string(filepath.Separator)))
+ fi, err := stat(strings.Join(parts[:i+1], string(sep)))
if err != nil {
return originalErr
}
@@ -79,7 +79,7 @@ func (d Dir) Open(name string) (File, error) {
fullName := filepath.Join(dir, filepath.FromSlash(path.Clean("/"+name)))
f, err := os.Open(fullName)
if err != nil {
- return nil, mapDirOpenError(err, fullName)
+ return nil, mapOpenError(err, fullName, filepath.Separator, os.Stat)
}
return f, nil
}
@@ -759,7 +759,9 @@ func (f ioFS) Open(name string) (File, error) {
}
file, err := f.fsys.Open(name)
if err != nil {
- return nil, err
+ return nil, mapOpenError(err, name, '/', func(path string) (fs.FileInfo, error) {
+ return fs.Stat(f.fsys, path)
+ })
}
return ioFile{file}, nil
}
@@ -881,11 +883,11 @@ func parseRange(s string, size int64) ([]httpRange, error) {
if ra == "" {
continue
}
- i := strings.Index(ra, "-")
- if i < 0 {
+ start, end, ok := strings.Cut(ra, "-")
+ if !ok {
return nil, errors.New("invalid range")
}
- start, end := textproto.TrimString(ra[:i]), textproto.TrimString(ra[i+1:])
+ start, end = textproto.TrimString(start), textproto.TrimString(end)
var r httpRange
if start == "" {
// If no start is specified, end specifies the
diff --git a/libgo/go/net/http/fs_test.go b/libgo/go/net/http/fs_test.go
index b42ade1..d627dfd 100644
--- a/libgo/go/net/http/fs_test.go
+++ b/libgo/go/net/http/fs_test.go
@@ -658,7 +658,7 @@ type fakeFileInfo struct {
}
func (f *fakeFileInfo) Name() string { return f.basename }
-func (f *fakeFileInfo) Sys() interface{} { return nil }
+func (f *fakeFileInfo) Sys() any { return nil }
func (f *fakeFileInfo) ModTime() time.Time { return f.modtime }
func (f *fakeFileInfo) IsDir() bool { return f.dir }
func (f *fakeFileInfo) Size() int64 { return int64(len(f.contents)) }
@@ -1244,10 +1244,19 @@ func TestLinuxSendfileChild(*testing.T) {
}
}
-// Issue 18984: tests that requests for paths beyond files return not-found errors
+// Issues 18984, 49552: tests that requests for paths beyond files return not-found errors
func TestFileServerNotDirError(t *testing.T) {
defer afterTest(t)
- ts := httptest.NewServer(FileServer(Dir("testdata")))
+ t.Run("Dir", func(t *testing.T) {
+ testFileServerNotDirError(t, func(path string) FileSystem { return Dir(path) })
+ })
+ t.Run("FS", func(t *testing.T) {
+ testFileServerNotDirError(t, func(path string) FileSystem { return FS(os.DirFS(path)) })
+ })
+}
+
+func testFileServerNotDirError(t *testing.T, newfs func(string) FileSystem) {
+ ts := httptest.NewServer(FileServer(newfs("testdata")))
defer ts.Close()
res, err := Get(ts.URL + "/index.html/not-a-file")
@@ -1259,9 +1268,9 @@ func TestFileServerNotDirError(t *testing.T) {
t.Errorf("StatusCode = %v; want 404", res.StatusCode)
}
- test := func(name string, dir Dir) {
+ test := func(name string, fsys FileSystem) {
t.Run(name, func(t *testing.T) {
- _, err = dir.Open("/index.html/not-a-file")
+ _, err = fsys.Open("/index.html/not-a-file")
if err == nil {
t.Fatal("err == nil; want != nil")
}
@@ -1270,7 +1279,7 @@ func TestFileServerNotDirError(t *testing.T) {
errors.Is(err, fs.ErrNotExist))
}
- _, err = dir.Open("/index.html/not-a-dir/not-a-file")
+ _, err = fsys.Open("/index.html/not-a-dir/not-a-file")
if err == nil {
t.Fatal("err == nil; want != nil")
}
@@ -1286,8 +1295,8 @@ func TestFileServerNotDirError(t *testing.T) {
t.Fatal("get abs path:", err)
}
- test("RelativePath", Dir("testdata"))
- test("AbsolutePath", Dir(absPath))
+ test("RelativePath", newfs("testdata"))
+ test("AbsolutePath", newfs(absPath))
}
func TestFileServerCleanPath(t *testing.T) {
diff --git a/libgo/go/net/http/h2_bundle.go b/libgo/go/net/http/h2_bundle.go
index 8958a9e..bb82f24 100644
--- a/libgo/go/net/http/h2_bundle.go
+++ b/libgo/go/net/http/h2_bundle.go
@@ -53,6 +53,10 @@ import (
"golang.org/x/net/idna"
)
+// The HTTP protocols are defined in terms of ASCII, not Unicode. This file
+// contains helper functions which may use Unicode-aware functions which would
+// otherwise be unsafe and could introduce vulnerabilities if used improperly.
+
// asciiEqualFold is strings.EqualFold, ASCII only. It reports whether s and t
// are equal, ASCII-case-insensitively.
func http2asciiEqualFold(s, t string) bool {
@@ -733,6 +737,12 @@ func http2isBadCipher(cipher uint16) bool {
// ClientConnPool manages a pool of HTTP/2 client connections.
type http2ClientConnPool interface {
+ // GetClientConn returns a specific HTTP/2 connection (usually
+ // a TLS-TCP connection) to an HTTP/2 server. On success, the
+ // returned ClientConn accounts for the upcoming RoundTrip
+ // call, so the caller should not omit it. If the caller needs
+ // to, ClientConn.RoundTrip can be called with a bogus
+ // new(http.Request) to release the stream reservation.
GetClientConn(req *Request, addr string) (*http2ClientConn, error)
MarkDead(*http2ClientConn)
}
@@ -759,7 +769,7 @@ type http2clientConnPool struct {
conns map[string][]*http2ClientConn // key is host:port
dialing map[string]*http2dialCall // currently in-flight dials
keys map[*http2ClientConn][]string
- addConnCalls map[string]*http2addConnCall // in-flight addConnIfNeede calls
+ addConnCalls map[string]*http2addConnCall // in-flight addConnIfNeeded calls
}
func (p *http2clientConnPool) GetClientConn(req *Request, addr string) (*http2ClientConn, error) {
@@ -771,28 +781,8 @@ const (
http2noDialOnMiss = false
)
-// shouldTraceGetConn reports whether getClientConn should call any
-// ClientTrace.GetConn hook associated with the http.Request.
-//
-// This complexity is needed to avoid double calls of the GetConn hook
-// during the back-and-forth between net/http and x/net/http2 (when the
-// net/http.Transport is upgraded to also speak http2), as well as support
-// the case where x/net/http2 is being used directly.
-func (p *http2clientConnPool) shouldTraceGetConn(st http2clientConnIdleState) bool {
- // If our Transport wasn't made via ConfigureTransport, always
- // trace the GetConn hook if provided, because that means the
- // http2 package is being used directly and it's the one
- // dialing, as opposed to net/http.
- if _, ok := p.t.ConnPool.(http2noDialClientConnPool); !ok {
- return true
- }
- // Otherwise, only use the GetConn hook if this connection has
- // been used previously for other requests. For fresh
- // connections, the net/http package does the dialing.
- return !st.freshConn
-}
-
func (p *http2clientConnPool) getClientConn(req *Request, addr string, dialOnMiss bool) (*http2ClientConn, error) {
+ // TODO(dneil): Dial a new connection when t.DisableKeepAlives is set?
if http2isConnectionCloseRequest(req) && dialOnMiss {
// It gets its own connection.
http2traceGetConn(req, addr)
@@ -806,10 +796,14 @@ func (p *http2clientConnPool) getClientConn(req *Request, addr string, dialOnMis
for {
p.mu.Lock()
for _, cc := range p.conns[addr] {
- if st := cc.idleState(); st.canTakeNewRequest {
- if p.shouldTraceGetConn(st) {
+ if cc.ReserveNewRequest() {
+ // When a connection is presented to us by the net/http package,
+ // the GetConn hook has already been called.
+ // Don't call it a second time here.
+ if !cc.getConnCalled {
http2traceGetConn(req, addr)
}
+ cc.getConnCalled = false
p.mu.Unlock()
return cc, nil
}
@@ -825,7 +819,13 @@ func (p *http2clientConnPool) getClientConn(req *Request, addr string, dialOnMis
if http2shouldRetryDial(call, req) {
continue
}
- return call.res, call.err
+ cc, err := call.res, call.err
+ if err != nil {
+ return nil, err
+ }
+ if cc.ReserveNewRequest() {
+ return cc, nil
+ }
}
}
@@ -922,6 +922,7 @@ func (c *http2addConnCall) run(t *http2Transport, key string, tc *tls.Conn) {
if err != nil {
c.err = err
} else {
+ cc.getConnCalled = true // already called by the net/http package
p.addConnLocked(key, cc)
}
delete(p.addConnCalls, key)
@@ -1208,6 +1209,13 @@ func (e http2ErrCode) String() string {
return fmt.Sprintf("unknown error code 0x%x", uint32(e))
}
+func (e http2ErrCode) stringToken() string {
+ if s, ok := http2errCodeName[e]; ok {
+ return s
+ }
+ return fmt.Sprintf("ERR_UNKNOWN_%d", uint32(e))
+}
+
// ConnectionError is an error that results in the termination of the
// entire connection.
type http2ConnectionError http2ErrCode
@@ -1224,6 +1232,11 @@ type http2StreamError struct {
Cause error // optional additional detail
}
+// errFromPeer is a sentinel error value for StreamError.Cause to
+// indicate that the StreamError was sent from the peer over the wire
+// and wasn't locally generated in the Transport.
+var http2errFromPeer = errors.New("received from peer")
+
func http2streamError(id uint32, code http2ErrCode) http2StreamError {
return http2StreamError{StreamID: id, Code: code}
}
@@ -1438,7 +1451,7 @@ var http2flagName = map[http2FrameType]map[http2Flags]string{
// a frameParser parses a frame given its FrameHeader and payload
// bytes. The length of payload will always equal fh.Length (which
// might be 0).
-type http2frameParser func(fc *http2frameCache, fh http2FrameHeader, payload []byte) (http2Frame, error)
+type http2frameParser func(fc *http2frameCache, fh http2FrameHeader, countError func(string), payload []byte) (http2Frame, error)
var http2frameParsers = map[http2FrameType]http2frameParser{
http2FrameData: http2parseDataFrame,
@@ -1583,6 +1596,11 @@ type http2Framer struct {
lastFrame http2Frame
errDetail error
+ // countError is a non-nil func that's called on a frame parse
+ // error with some unique error path token. It's initialized
+ // from Transport.CountError or Server.CountError.
+ countError func(errToken string)
+
// lastHeaderStream is non-zero if the last frame was an
// unfinished HEADERS/CONTINUATION.
lastHeaderStream uint32
@@ -1745,6 +1763,7 @@ func http2NewFramer(w io.Writer, r io.Reader) *http2Framer {
fr := &http2Framer{
w: w,
r: r,
+ countError: func(string) {},
logReads: http2logFrameReads,
logWrites: http2logFrameWrites,
debugReadLoggerf: log.Printf,
@@ -1819,7 +1838,7 @@ func (fr *http2Framer) ReadFrame() (http2Frame, error) {
if _, err := io.ReadFull(fr.r, payload); err != nil {
return nil, err
}
- f, err := http2typeFrameParser(fh.Type)(fr.frameCache, fh, payload)
+ f, err := http2typeFrameParser(fh.Type)(fr.frameCache, fh, fr.countError, payload)
if err != nil {
if ce, ok := err.(http2connError); ok {
return nil, fr.connError(ce.Code, ce.Reason)
@@ -1907,13 +1926,14 @@ func (f *http2DataFrame) Data() []byte {
return f.data
}
-func http2parseDataFrame(fc *http2frameCache, fh http2FrameHeader, payload []byte) (http2Frame, error) {
+func http2parseDataFrame(fc *http2frameCache, fh http2FrameHeader, countError func(string), payload []byte) (http2Frame, error) {
if fh.StreamID == 0 {
// DATA frames MUST be associated with a stream. If a
// DATA frame is received whose stream identifier
// field is 0x0, the recipient MUST respond with a
// connection error (Section 5.4.1) of type
// PROTOCOL_ERROR.
+ countError("frame_data_stream_0")
return nil, http2connError{http2ErrCodeProtocol, "DATA frame with stream ID 0"}
}
f := fc.getDataFrame()
@@ -1924,6 +1944,7 @@ func http2parseDataFrame(fc *http2frameCache, fh http2FrameHeader, payload []byt
var err error
payload, padSize, err = http2readByte(payload)
if err != nil {
+ countError("frame_data_pad_byte_short")
return nil, err
}
}
@@ -1932,6 +1953,7 @@ func http2parseDataFrame(fc *http2frameCache, fh http2FrameHeader, payload []byt
// length of the frame payload, the recipient MUST
// treat this as a connection error.
// Filed: https://github.com/http2/http2-spec/issues/610
+ countError("frame_data_pad_too_big")
return nil, http2connError{http2ErrCodeProtocol, "pad size larger than data payload"}
}
f.data = payload[:len(payload)-int(padSize)]
@@ -2014,7 +2036,7 @@ type http2SettingsFrame struct {
p []byte
}
-func http2parseSettingsFrame(_ *http2frameCache, fh http2FrameHeader, p []byte) (http2Frame, error) {
+func http2parseSettingsFrame(_ *http2frameCache, fh http2FrameHeader, countError func(string), p []byte) (http2Frame, error) {
if fh.Flags.Has(http2FlagSettingsAck) && fh.Length > 0 {
// When this (ACK 0x1) bit is set, the payload of the
// SETTINGS frame MUST be empty. Receipt of a
@@ -2022,6 +2044,7 @@ func http2parseSettingsFrame(_ *http2frameCache, fh http2FrameHeader, p []byte)
// field value other than 0 MUST be treated as a
// connection error (Section 5.4.1) of type
// FRAME_SIZE_ERROR.
+ countError("frame_settings_ack_with_length")
return nil, http2ConnectionError(http2ErrCodeFrameSize)
}
if fh.StreamID != 0 {
@@ -2032,14 +2055,17 @@ func http2parseSettingsFrame(_ *http2frameCache, fh http2FrameHeader, p []byte)
// field is anything other than 0x0, the endpoint MUST
// respond with a connection error (Section 5.4.1) of
// type PROTOCOL_ERROR.
+ countError("frame_settings_has_stream")
return nil, http2ConnectionError(http2ErrCodeProtocol)
}
if len(p)%6 != 0 {
+ countError("frame_settings_mod_6")
// Expecting even number of 6 byte settings.
return nil, http2ConnectionError(http2ErrCodeFrameSize)
}
f := &http2SettingsFrame{http2FrameHeader: fh, p: p}
if v, ok := f.Value(http2SettingInitialWindowSize); ok && v > (1<<31)-1 {
+ countError("frame_settings_window_size_too_big")
// Values above the maximum flow control window size of 2^31 - 1 MUST
// be treated as a connection error (Section 5.4.1) of type
// FLOW_CONTROL_ERROR.
@@ -2151,11 +2177,13 @@ type http2PingFrame struct {
func (f *http2PingFrame) IsAck() bool { return f.Flags.Has(http2FlagPingAck) }
-func http2parsePingFrame(_ *http2frameCache, fh http2FrameHeader, payload []byte) (http2Frame, error) {
+func http2parsePingFrame(_ *http2frameCache, fh http2FrameHeader, countError func(string), payload []byte) (http2Frame, error) {
if len(payload) != 8 {
+ countError("frame_ping_length")
return nil, http2ConnectionError(http2ErrCodeFrameSize)
}
if fh.StreamID != 0 {
+ countError("frame_ping_has_stream")
return nil, http2ConnectionError(http2ErrCodeProtocol)
}
f := &http2PingFrame{http2FrameHeader: fh}
@@ -2191,11 +2219,13 @@ func (f *http2GoAwayFrame) DebugData() []byte {
return f.debugData
}
-func http2parseGoAwayFrame(_ *http2frameCache, fh http2FrameHeader, p []byte) (http2Frame, error) {
+func http2parseGoAwayFrame(_ *http2frameCache, fh http2FrameHeader, countError func(string), p []byte) (http2Frame, error) {
if fh.StreamID != 0 {
+ countError("frame_goaway_has_stream")
return nil, http2ConnectionError(http2ErrCodeProtocol)
}
if len(p) < 8 {
+ countError("frame_goaway_short")
return nil, http2ConnectionError(http2ErrCodeFrameSize)
}
return &http2GoAwayFrame{
@@ -2231,7 +2261,7 @@ func (f *http2UnknownFrame) Payload() []byte {
return f.p
}
-func http2parseUnknownFrame(_ *http2frameCache, fh http2FrameHeader, p []byte) (http2Frame, error) {
+func http2parseUnknownFrame(_ *http2frameCache, fh http2FrameHeader, countError func(string), p []byte) (http2Frame, error) {
return &http2UnknownFrame{fh, p}, nil
}
@@ -2242,8 +2272,9 @@ type http2WindowUpdateFrame struct {
Increment uint32 // never read with high bit set
}
-func http2parseWindowUpdateFrame(_ *http2frameCache, fh http2FrameHeader, p []byte) (http2Frame, error) {
+func http2parseWindowUpdateFrame(_ *http2frameCache, fh http2FrameHeader, countError func(string), p []byte) (http2Frame, error) {
if len(p) != 4 {
+ countError("frame_windowupdate_bad_len")
return nil, http2ConnectionError(http2ErrCodeFrameSize)
}
inc := binary.BigEndian.Uint32(p[:4]) & 0x7fffffff // mask off high reserved bit
@@ -2255,8 +2286,10 @@ func http2parseWindowUpdateFrame(_ *http2frameCache, fh http2FrameHeader, p []by
// control window MUST be treated as a connection
// error (Section 5.4.1).
if fh.StreamID == 0 {
+ countError("frame_windowupdate_zero_inc_conn")
return nil, http2ConnectionError(http2ErrCodeProtocol)
}
+ countError("frame_windowupdate_zero_inc_stream")
return nil, http2streamError(fh.StreamID, http2ErrCodeProtocol)
}
return &http2WindowUpdateFrame{
@@ -2307,7 +2340,7 @@ func (f *http2HeadersFrame) HasPriority() bool {
return f.http2FrameHeader.Flags.Has(http2FlagHeadersPriority)
}
-func http2parseHeadersFrame(_ *http2frameCache, fh http2FrameHeader, p []byte) (_ http2Frame, err error) {
+func http2parseHeadersFrame(_ *http2frameCache, fh http2FrameHeader, countError func(string), p []byte) (_ http2Frame, err error) {
hf := &http2HeadersFrame{
http2FrameHeader: fh,
}
@@ -2316,11 +2349,13 @@ func http2parseHeadersFrame(_ *http2frameCache, fh http2FrameHeader, p []byte) (
// is received whose stream identifier field is 0x0, the recipient MUST
// respond with a connection error (Section 5.4.1) of type
// PROTOCOL_ERROR.
+ countError("frame_headers_zero_stream")
return nil, http2connError{http2ErrCodeProtocol, "HEADERS frame with stream ID 0"}
}
var padLength uint8
if fh.Flags.Has(http2FlagHeadersPadded) {
if p, padLength, err = http2readByte(p); err != nil {
+ countError("frame_headers_pad_short")
return
}
}
@@ -2328,16 +2363,19 @@ func http2parseHeadersFrame(_ *http2frameCache, fh http2FrameHeader, p []byte) (
var v uint32
p, v, err = http2readUint32(p)
if err != nil {
+ countError("frame_headers_prio_short")
return nil, err
}
hf.Priority.StreamDep = v & 0x7fffffff
hf.Priority.Exclusive = (v != hf.Priority.StreamDep) // high bit was set
p, hf.Priority.Weight, err = http2readByte(p)
if err != nil {
+ countError("frame_headers_prio_weight_short")
return nil, err
}
}
- if len(p)-int(padLength) <= 0 {
+ if len(p)-int(padLength) < 0 {
+ countError("frame_headers_pad_too_big")
return nil, http2streamError(fh.StreamID, http2ErrCodeProtocol)
}
hf.headerFragBuf = p[:len(p)-int(padLength)]
@@ -2444,11 +2482,13 @@ func (p http2PriorityParam) IsZero() bool {
return p == http2PriorityParam{}
}
-func http2parsePriorityFrame(_ *http2frameCache, fh http2FrameHeader, payload []byte) (http2Frame, error) {
+func http2parsePriorityFrame(_ *http2frameCache, fh http2FrameHeader, countError func(string), payload []byte) (http2Frame, error) {
if fh.StreamID == 0 {
+ countError("frame_priority_zero_stream")
return nil, http2connError{http2ErrCodeProtocol, "PRIORITY frame with stream ID 0"}
}
if len(payload) != 5 {
+ countError("frame_priority_bad_length")
return nil, http2connError{http2ErrCodeFrameSize, fmt.Sprintf("PRIORITY frame payload size was %d; want 5", len(payload))}
}
v := binary.BigEndian.Uint32(payload[:4])
@@ -2491,11 +2531,13 @@ type http2RSTStreamFrame struct {
ErrCode http2ErrCode
}
-func http2parseRSTStreamFrame(_ *http2frameCache, fh http2FrameHeader, p []byte) (http2Frame, error) {
+func http2parseRSTStreamFrame(_ *http2frameCache, fh http2FrameHeader, countError func(string), p []byte) (http2Frame, error) {
if len(p) != 4 {
+ countError("frame_rststream_bad_len")
return nil, http2ConnectionError(http2ErrCodeFrameSize)
}
if fh.StreamID == 0 {
+ countError("frame_rststream_zero_stream")
return nil, http2ConnectionError(http2ErrCodeProtocol)
}
return &http2RSTStreamFrame{fh, http2ErrCode(binary.BigEndian.Uint32(p[:4]))}, nil
@@ -2521,8 +2563,9 @@ type http2ContinuationFrame struct {
headerFragBuf []byte
}
-func http2parseContinuationFrame(_ *http2frameCache, fh http2FrameHeader, p []byte) (http2Frame, error) {
+func http2parseContinuationFrame(_ *http2frameCache, fh http2FrameHeader, countError func(string), p []byte) (http2Frame, error) {
if fh.StreamID == 0 {
+ countError("frame_continuation_zero_stream")
return nil, http2connError{http2ErrCodeProtocol, "CONTINUATION frame with stream ID 0"}
}
return &http2ContinuationFrame{fh, p}, nil
@@ -2571,7 +2614,7 @@ func (f *http2PushPromiseFrame) HeadersEnded() bool {
return f.http2FrameHeader.Flags.Has(http2FlagPushPromiseEndHeaders)
}
-func http2parsePushPromise(_ *http2frameCache, fh http2FrameHeader, p []byte) (_ http2Frame, err error) {
+func http2parsePushPromise(_ *http2frameCache, fh http2FrameHeader, countError func(string), p []byte) (_ http2Frame, err error) {
pp := &http2PushPromiseFrame{
http2FrameHeader: fh,
}
@@ -2582,6 +2625,7 @@ func http2parsePushPromise(_ *http2frameCache, fh http2FrameHeader, p []byte) (_
// with. If the stream identifier field specifies the value
// 0x0, a recipient MUST respond with a connection error
// (Section 5.4.1) of type PROTOCOL_ERROR.
+ countError("frame_pushpromise_zero_stream")
return nil, http2ConnectionError(http2ErrCodeProtocol)
}
// The PUSH_PROMISE frame includes optional padding.
@@ -2589,18 +2633,21 @@ func http2parsePushPromise(_ *http2frameCache, fh http2FrameHeader, p []byte) (_
var padLength uint8
if fh.Flags.Has(http2FlagPushPromisePadded) {
if p, padLength, err = http2readByte(p); err != nil {
+ countError("frame_pushpromise_pad_short")
return
}
}
p, pp.PromiseID, err = http2readUint32(p)
if err != nil {
+ countError("frame_pushpromise_promiseid_short")
return
}
pp.PromiseID = pp.PromiseID & (1<<31 - 1)
if int(padLength) > len(p) {
// like the DATA frame, error out if padding is longer than the body.
+ countError("frame_pushpromise_pad_too_big")
return nil, http2ConnectionError(http2ErrCodeProtocol)
}
pp.headerFragBuf = p[:len(p)-int(padLength)]
@@ -3570,6 +3617,17 @@ type http2pipeBuffer interface {
io.Reader
}
+// setBuffer initializes the pipe buffer.
+// It has no effect if the pipe is already closed.
+func (p *http2pipe) setBuffer(b http2pipeBuffer) {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+ if p.err != nil || p.breakErr != nil {
+ return
+ }
+ p.b = b
+}
+
func (p *http2pipe) Len() int {
p.mu.Lock()
defer p.mu.Unlock()
@@ -3786,6 +3844,12 @@ type http2Server struct {
// If nil, a default scheduler is chosen.
NewWriteScheduler func() http2WriteScheduler
+ // CountError, if non-nil, is called on HTTP/2 server errors.
+ // It's intended to increment a metric for monitoring, such
+ // as an expvar or Prometheus metric.
+ // The errType consists of only ASCII word characters.
+ CountError func(errType string)
+
// Internal state. This is a pointer (rather than embedded directly)
// so that we don't embed a Mutex in this struct, which will make the
// struct non-copyable, which might break some callers.
@@ -3915,16 +3979,12 @@ func http2ConfigureServer(s *Server, conf *http2Server) error {
s.TLSConfig.PreferServerCipherSuites = true
- haveNPN := false
- for _, p := range s.TLSConfig.NextProtos {
- if p == http2NextProtoTLS {
- haveNPN = true
- break
- }
- }
- if !haveNPN {
+ if !http2strSliceContains(s.TLSConfig.NextProtos, http2NextProtoTLS) {
s.TLSConfig.NextProtos = append(s.TLSConfig.NextProtos, http2NextProtoTLS)
}
+ if !http2strSliceContains(s.TLSConfig.NextProtos, "http/1.1") {
+ s.TLSConfig.NextProtos = append(s.TLSConfig.NextProtos, "http/1.1")
+ }
if s.TLSNextProto == nil {
s.TLSNextProto = map[string]func(*Server, *tls.Conn, Handler){}
@@ -4065,6 +4125,9 @@ func (s *http2Server) ServeConn(c net.Conn, opts *http2ServeConnOpts) {
sc.hpackEncoder = hpack.NewEncoder(&sc.headerWriteBuf)
fr := http2NewFramer(sc.bw, c)
+ if s.CountError != nil {
+ fr.countError = s.CountError
+ }
fr.ReadMetaHeaders = hpack.NewDecoder(http2initialHeaderTableSize, nil)
fr.MaxHeaderListSize = sc.maxHeaderListSize()
fr.SetMaxReadFrameSize(s.maxReadFrameSize())
@@ -4373,7 +4436,15 @@ func (sc *http2serverConn) canonicalHeader(v string) string {
sc.canonHeader = make(map[string]string)
}
cv = CanonicalHeaderKey(v)
- sc.canonHeader[v] = cv
+ // maxCachedCanonicalHeaders is an arbitrarily-chosen limit on the number of
+ // entries in the canonHeader cache. This should be larger than the number
+ // of unique, uncommon header keys likely to be sent by the peer, while not
+ // so high as to permit unreaasonable memory usage if the peer sends an unbounded
+ // number of unique header keys.
+ const maxCachedCanonicalHeaders = 32
+ if len(sc.canonHeader) < maxCachedCanonicalHeaders {
+ sc.canonHeader[v] = cv
+ }
return cv
}
@@ -4479,7 +4550,7 @@ func (sc *http2serverConn) serve() {
})
sc.unackedSettings++
- // Each connection starts with intialWindowSize inflow tokens.
+ // Each connection starts with initialWindowSize inflow tokens.
// If a higher value is configured, we add more tokens.
if diff := sc.srv.initialConnRecvWindowSize() - http2initialWindowSize; diff > 0 {
sc.sendWindowUpdate(nil, int(diff))
@@ -5064,7 +5135,7 @@ func (sc *http2serverConn) processFrame(f http2Frame) error {
// First frame received must be SETTINGS.
if !sc.sawFirstSettings {
if _, ok := f.(*http2SettingsFrame); !ok {
- return http2ConnectionError(http2ErrCodeProtocol)
+ return sc.countError("first_settings", http2ConnectionError(http2ErrCodeProtocol))
}
sc.sawFirstSettings = true
}
@@ -5089,7 +5160,7 @@ func (sc *http2serverConn) processFrame(f http2Frame) error {
case *http2PushPromiseFrame:
// A client cannot push. Thus, servers MUST treat the receipt of a PUSH_PROMISE
// frame as a connection error (Section 5.4.1) of type PROTOCOL_ERROR.
- return http2ConnectionError(http2ErrCodeProtocol)
+ return sc.countError("push_promise", http2ConnectionError(http2ErrCodeProtocol))
default:
sc.vlogf("http2: server ignoring frame: %v", f.Header())
return nil
@@ -5109,7 +5180,7 @@ func (sc *http2serverConn) processPing(f *http2PingFrame) error {
// identifier field value other than 0x0, the recipient MUST
// respond with a connection error (Section 5.4.1) of type
// PROTOCOL_ERROR."
- return http2ConnectionError(http2ErrCodeProtocol)
+ return sc.countError("ping_on_stream", http2ConnectionError(http2ErrCodeProtocol))
}
if sc.inGoAway && sc.goAwayCode != http2ErrCodeNo {
return nil
@@ -5128,7 +5199,7 @@ func (sc *http2serverConn) processWindowUpdate(f *http2WindowUpdateFrame) error
// or PRIORITY on a stream in this state MUST be
// treated as a connection error (Section 5.4.1) of
// type PROTOCOL_ERROR."
- return http2ConnectionError(http2ErrCodeProtocol)
+ return sc.countError("stream_idle", http2ConnectionError(http2ErrCodeProtocol))
}
if st == nil {
// "WINDOW_UPDATE can be sent by a peer that has sent a
@@ -5139,7 +5210,7 @@ func (sc *http2serverConn) processWindowUpdate(f *http2WindowUpdateFrame) error
return nil
}
if !st.flow.add(int32(f.Increment)) {
- return http2streamError(f.StreamID, http2ErrCodeFlowControl)
+ return sc.countError("bad_flow", http2streamError(f.StreamID, http2ErrCodeFlowControl))
}
default: // connection-level flow control
if !sc.flow.add(int32(f.Increment)) {
@@ -5160,7 +5231,7 @@ func (sc *http2serverConn) processResetStream(f *http2RSTStreamFrame) error {
// identifying an idle stream is received, the
// recipient MUST treat this as a connection error
// (Section 5.4.1) of type PROTOCOL_ERROR.
- return http2ConnectionError(http2ErrCodeProtocol)
+ return sc.countError("reset_idle_stream", http2ConnectionError(http2ErrCodeProtocol))
}
if st != nil {
st.cancelCtx()
@@ -5212,7 +5283,7 @@ func (sc *http2serverConn) processSettings(f *http2SettingsFrame) error {
// Why is the peer ACKing settings we never sent?
// The spec doesn't mention this case, but
// hang up on them anyway.
- return http2ConnectionError(http2ErrCodeProtocol)
+ return sc.countError("ack_mystery", http2ConnectionError(http2ErrCodeProtocol))
}
return nil
}
@@ -5220,7 +5291,7 @@ func (sc *http2serverConn) processSettings(f *http2SettingsFrame) error {
// This isn't actually in the spec, but hang up on
// suspiciously large settings frames or those with
// duplicate entries.
- return http2ConnectionError(http2ErrCodeProtocol)
+ return sc.countError("settings_big_or_dups", http2ConnectionError(http2ErrCodeProtocol))
}
if err := f.ForeachSetting(sc.processSetting); err != nil {
return err
@@ -5287,7 +5358,7 @@ func (sc *http2serverConn) processSettingInitialWindowSize(val uint32) error {
// control window to exceed the maximum size as a
// connection error (Section 5.4.1) of type
// FLOW_CONTROL_ERROR."
- return http2ConnectionError(http2ErrCodeFlowControl)
+ return sc.countError("setting_win_size", http2ConnectionError(http2ErrCodeFlowControl))
}
}
return nil
@@ -5320,7 +5391,7 @@ func (sc *http2serverConn) processData(f *http2DataFrame) error {
// or PRIORITY on a stream in this state MUST be
// treated as a connection error (Section 5.4.1) of
// type PROTOCOL_ERROR."
- return http2ConnectionError(http2ErrCodeProtocol)
+ return sc.countError("data_on_idle", http2ConnectionError(http2ErrCodeProtocol))
}
// "If a DATA frame is received whose stream is not in "open"
@@ -5337,7 +5408,7 @@ func (sc *http2serverConn) processData(f *http2DataFrame) error {
// and return any flow control bytes since we're not going
// to consume them.
if sc.inflow.available() < int32(f.Length) {
- return http2streamError(id, http2ErrCodeFlowControl)
+ return sc.countError("data_flow", http2streamError(id, http2ErrCodeFlowControl))
}
// Deduct the flow control from inflow, since we're
// going to immediately add it back in
@@ -5350,7 +5421,7 @@ func (sc *http2serverConn) processData(f *http2DataFrame) error {
// Already have a stream error in flight. Don't send another.
return nil
}
- return http2streamError(id, http2ErrCodeStreamClosed)
+ return sc.countError("closed", http2streamError(id, http2ErrCodeStreamClosed))
}
if st.body == nil {
panic("internal error: should have a body in this state")
@@ -5362,12 +5433,12 @@ func (sc *http2serverConn) processData(f *http2DataFrame) error {
// RFC 7540, sec 8.1.2.6: A request or response is also malformed if the
// value of a content-length header field does not equal the sum of the
// DATA frame payload lengths that form the body.
- return http2streamError(id, http2ErrCodeProtocol)
+ return sc.countError("send_too_much", http2streamError(id, http2ErrCodeProtocol))
}
if f.Length > 0 {
// Check whether the client has flow control quota.
if st.inflow.available() < int32(f.Length) {
- return http2streamError(id, http2ErrCodeFlowControl)
+ return sc.countError("flow_on_data_length", http2streamError(id, http2ErrCodeFlowControl))
}
st.inflow.take(int32(f.Length))
@@ -5375,7 +5446,7 @@ func (sc *http2serverConn) processData(f *http2DataFrame) error {
wrote, err := st.body.Write(data)
if err != nil {
sc.sendWindowUpdate(nil, int(f.Length)-wrote)
- return http2streamError(id, http2ErrCodeStreamClosed)
+ return sc.countError("body_write_err", http2streamError(id, http2ErrCodeStreamClosed))
}
if wrote != len(data) {
panic("internal error: bad Writer")
@@ -5461,7 +5532,7 @@ func (sc *http2serverConn) processHeaders(f *http2MetaHeadersFrame) error {
// stream identifier MUST respond with a connection error
// (Section 5.4.1) of type PROTOCOL_ERROR.
if id%2 != 1 {
- return http2ConnectionError(http2ErrCodeProtocol)
+ return sc.countError("headers_even", http2ConnectionError(http2ErrCodeProtocol))
}
// A HEADERS frame can be used to create a new stream or
// send a trailer for an open one. If we already have a stream
@@ -5478,7 +5549,7 @@ func (sc *http2serverConn) processHeaders(f *http2MetaHeadersFrame) error {
// this state, it MUST respond with a stream error (Section 5.4.2) of
// type STREAM_CLOSED.
if st.state == http2stateHalfClosedRemote {
- return http2streamError(id, http2ErrCodeStreamClosed)
+ return sc.countError("headers_half_closed", http2streamError(id, http2ErrCodeStreamClosed))
}
return st.processTrailerHeaders(f)
}
@@ -5489,7 +5560,7 @@ func (sc *http2serverConn) processHeaders(f *http2MetaHeadersFrame) error {
// receives an unexpected stream identifier MUST respond with
// a connection error (Section 5.4.1) of type PROTOCOL_ERROR.
if id <= sc.maxClientStreamID {
- return http2ConnectionError(http2ErrCodeProtocol)
+ return sc.countError("stream_went_down", http2ConnectionError(http2ErrCodeProtocol))
}
sc.maxClientStreamID = id
@@ -5506,14 +5577,14 @@ func (sc *http2serverConn) processHeaders(f *http2MetaHeadersFrame) error {
if sc.curClientStreams+1 > sc.advMaxStreams {
if sc.unackedSettings == 0 {
// They should know better.
- return http2streamError(id, http2ErrCodeProtocol)
+ return sc.countError("over_max_streams", http2streamError(id, http2ErrCodeProtocol))
}
// Assume it's a network race, where they just haven't
// received our last SETTINGS update. But actually
// this can't happen yet, because we don't yet provide
// a way for users to adjust server parameters at
// runtime.
- return http2streamError(id, http2ErrCodeRefusedStream)
+ return sc.countError("over_max_streams_race", http2streamError(id, http2ErrCodeRefusedStream))
}
initialState := http2stateOpen
@@ -5523,7 +5594,7 @@ func (sc *http2serverConn) processHeaders(f *http2MetaHeadersFrame) error {
st := sc.newStream(id, 0, initialState)
if f.HasPriority() {
- if err := http2checkPriority(f.StreamID, f.Priority); err != nil {
+ if err := sc.checkPriority(f.StreamID, f.Priority); err != nil {
return err
}
sc.writeSched.AdjustStream(st.id, f.Priority)
@@ -5567,15 +5638,15 @@ func (st *http2stream) processTrailerHeaders(f *http2MetaHeadersFrame) error {
sc := st.sc
sc.serveG.check()
if st.gotTrailerHeader {
- return http2ConnectionError(http2ErrCodeProtocol)
+ return sc.countError("dup_trailers", http2ConnectionError(http2ErrCodeProtocol))
}
st.gotTrailerHeader = true
if !f.StreamEnded() {
- return http2streamError(st.id, http2ErrCodeProtocol)
+ return sc.countError("trailers_not_ended", http2streamError(st.id, http2ErrCodeProtocol))
}
if len(f.PseudoFields()) > 0 {
- return http2streamError(st.id, http2ErrCodeProtocol)
+ return sc.countError("trailers_pseudo", http2streamError(st.id, http2ErrCodeProtocol))
}
if st.trailer != nil {
for _, hf := range f.RegularFields() {
@@ -5584,7 +5655,7 @@ func (st *http2stream) processTrailerHeaders(f *http2MetaHeadersFrame) error {
// TODO: send more details to the peer somehow. But http2 has
// no way to send debug data at a stream level. Discuss with
// HTTP folk.
- return http2streamError(st.id, http2ErrCodeProtocol)
+ return sc.countError("trailers_bogus", http2streamError(st.id, http2ErrCodeProtocol))
}
st.trailer[key] = append(st.trailer[key], hf.Value)
}
@@ -5593,13 +5664,13 @@ func (st *http2stream) processTrailerHeaders(f *http2MetaHeadersFrame) error {
return nil
}
-func http2checkPriority(streamID uint32, p http2PriorityParam) error {
+func (sc *http2serverConn) checkPriority(streamID uint32, p http2PriorityParam) error {
if streamID == p.StreamDep {
// Section 5.3.1: "A stream cannot depend on itself. An endpoint MUST treat
// this as a stream error (Section 5.4.2) of type PROTOCOL_ERROR."
// Section 5.3.3 says that a stream can depend on one of its dependencies,
// so it's only self-dependencies that are forbidden.
- return http2streamError(streamID, http2ErrCodeProtocol)
+ return sc.countError("priority", http2streamError(streamID, http2ErrCodeProtocol))
}
return nil
}
@@ -5608,7 +5679,7 @@ func (sc *http2serverConn) processPriority(f *http2PriorityFrame) error {
if sc.inGoAway {
return nil
}
- if err := http2checkPriority(f.StreamID, f.http2PriorityParam); err != nil {
+ if err := sc.checkPriority(f.StreamID, f.http2PriorityParam); err != nil {
return err
}
sc.writeSched.AdjustStream(f.StreamID, f.http2PriorityParam)
@@ -5665,7 +5736,7 @@ func (sc *http2serverConn) newWriterAndRequest(st *http2stream, f *http2MetaHead
isConnect := rp.method == "CONNECT"
if isConnect {
if rp.path != "" || rp.scheme != "" || rp.authority == "" {
- return nil, nil, http2streamError(f.StreamID, http2ErrCodeProtocol)
+ return nil, nil, sc.countError("bad_connect", http2streamError(f.StreamID, http2ErrCodeProtocol))
}
} else if rp.method == "" || rp.path == "" || (rp.scheme != "https" && rp.scheme != "http") {
// See 8.1.2.6 Malformed Requests and Responses:
@@ -5678,13 +5749,13 @@ func (sc *http2serverConn) newWriterAndRequest(st *http2stream, f *http2MetaHead
// "All HTTP/2 requests MUST include exactly one valid
// value for the :method, :scheme, and :path
// pseudo-header fields"
- return nil, nil, http2streamError(f.StreamID, http2ErrCodeProtocol)
+ return nil, nil, sc.countError("bad_path_method", http2streamError(f.StreamID, http2ErrCodeProtocol))
}
bodyOpen := !f.StreamEnded()
if rp.method == "HEAD" && bodyOpen {
// HEAD requests can't have bodies
- return nil, nil, http2streamError(f.StreamID, http2ErrCodeProtocol)
+ return nil, nil, sc.countError("head_body", http2streamError(f.StreamID, http2ErrCodeProtocol))
}
rp.header = make(Header)
@@ -5767,7 +5838,7 @@ func (sc *http2serverConn) newWriterAndRequestNoBody(st *http2stream, rp http2re
var err error
url_, err = url.ParseRequestURI(rp.path)
if err != nil {
- return nil, nil, http2streamError(st.id, http2ErrCodeProtocol)
+ return nil, nil, sc.countError("bad_path", http2streamError(st.id, http2ErrCodeProtocol))
}
requestURI = rp.path
}
@@ -6651,6 +6722,34 @@ func http2h1ServerKeepAlivesDisabled(hs *Server) bool {
return false
}
+func (sc *http2serverConn) countError(name string, err error) error {
+ if sc == nil || sc.srv == nil {
+ return err
+ }
+ f := sc.srv.CountError
+ if f == nil {
+ return err
+ }
+ var typ string
+ var code http2ErrCode
+ switch e := err.(type) {
+ case http2ConnectionError:
+ typ = "conn"
+ code = http2ErrCode(e)
+ case http2StreamError:
+ typ = "stream"
+ code = http2ErrCode(e.Code)
+ default:
+ return err
+ }
+ codeStr := http2errCodeName[code]
+ if codeStr == "" {
+ codeStr = strconv.Itoa(int(code))
+ }
+ f(fmt.Sprintf("%s_%s_%s", typ, codeStr, name))
+ return err
+}
+
const (
// transportDefaultConnFlow is how many connection-level flow control
// tokens we give the server at start-up, past the default 64k.
@@ -6666,6 +6765,15 @@ const (
http2transportDefaultStreamMinRefresh = 4 << 10
http2defaultUserAgent = "Go-http-client/2.0"
+
+ // initialMaxConcurrentStreams is a connections maxConcurrentStreams until
+ // it's received servers initial SETTINGS frame, which corresponds with the
+ // spec's minimum recommended value.
+ http2initialMaxConcurrentStreams = 100
+
+ // defaultMaxConcurrentStreams is a connections default maxConcurrentStreams
+ // if the server doesn't include one in its initial SETTINGS frame.
+ http2defaultMaxConcurrentStreams = 1000
)
// Transport is an HTTP/2 Transport.
@@ -6736,6 +6844,17 @@ type http2Transport struct {
// Defaults to 15s.
PingTimeout time.Duration
+ // WriteByteTimeout is the timeout after which the connection will be
+ // closed no data can be written to it. The timeout begins when data is
+ // available to write, and is extended whenever any bytes are written.
+ WriteByteTimeout time.Duration
+
+ // CountError, if non-nil, is called on HTTP/2 transport errors.
+ // It's intended to increment a metric for monitoring, such
+ // as an expvar or Prometheus metric.
+ // The errType consists of only ASCII word characters.
+ CountError func(errType string)
+
// t1, if non-nil, is the standard library Transport using
// this transport. Its settings are used (but not its
// RoundTrip method, etc).
@@ -6842,11 +6961,12 @@ func (t *http2Transport) initConnPool() {
// ClientConn is the state of a single HTTP/2 client connection to an
// HTTP/2 server.
type http2ClientConn struct {
- t *http2Transport
- tconn net.Conn // usually *tls.Conn, except specialized impls
- tlsState *tls.ConnectionState // nil only for specialized impls
- reused uint32 // whether conn is being reused; atomic
- singleUse bool // whether being used for a single http.Request
+ t *http2Transport
+ tconn net.Conn // usually *tls.Conn, except specialized impls
+ tlsState *tls.ConnectionState // nil only for specialized impls
+ reused uint32 // whether conn is being reused; atomic
+ singleUse bool // whether being used for a single http.Request
+ getConnCalled bool // used by clientConnPool
// readLoop goroutine fields:
readerDone chan struct{} // closed on error
@@ -6859,87 +6979,94 @@ type http2ClientConn struct {
cond *sync.Cond // hold mu; broadcast on flow/closed changes
flow http2flow // our conn-level flow control quota (cs.flow is per stream)
inflow http2flow // peer's conn-level flow control
+ doNotReuse bool // whether conn is marked to not be reused for any future requests
closing bool
closed bool
+ seenSettings bool // true if we've seen a settings frame, false otherwise
wantSettingsAck bool // we sent a SETTINGS frame and haven't heard back
goAway *http2GoAwayFrame // if non-nil, the GoAwayFrame we received
goAwayDebug string // goAway frame's debug data, retained as a string
streams map[uint32]*http2clientStream // client-initiated
+ streamsReserved int // incr by ReserveNewRequest; decr on RoundTrip
nextStreamID uint32
pendingRequests int // requests blocked and waiting to be sent because len(streams) == maxConcurrentStreams
pings map[[8]byte]chan struct{} // in flight ping data to notification channel
- bw *bufio.Writer
br *bufio.Reader
- fr *http2Framer
lastActive time.Time
lastIdle time.Time // time last idle
- // Settings from peer: (also guarded by mu)
+ // Settings from peer: (also guarded by wmu)
maxFrameSize uint32
maxConcurrentStreams uint32
peerMaxHeaderListSize uint64
initialWindowSize uint32
- hbuf bytes.Buffer // HPACK encoder writes into this
- henc *hpack.Encoder
- freeBuf [][]byte
+ // reqHeaderMu is a 1-element semaphore channel controlling access to sending new requests.
+ // Write to reqHeaderMu to lock it, read from it to unlock.
+ // Lock reqmu BEFORE mu or wmu.
+ reqHeaderMu chan struct{}
- wmu sync.Mutex // held while writing; acquire AFTER mu if holding both
- werr error // first write error that has occurred
+ // wmu is held while writing.
+ // Acquire BEFORE mu when holding both, to avoid blocking mu on network writes.
+ // Only acquire both at the same time when changing peer settings.
+ wmu sync.Mutex
+ bw *bufio.Writer
+ fr *http2Framer
+ werr error // first write error that has occurred
+ hbuf bytes.Buffer // HPACK encoder writes into this
+ henc *hpack.Encoder
}
// clientStream is the state for a single HTTP/2 stream. One of these
// is created for each Transport.RoundTrip call.
type http2clientStream struct {
- cc *http2ClientConn
- req *Request
+ cc *http2ClientConn
+
+ // Fields of Request that we may access even after the response body is closed.
+ ctx context.Context
+ reqCancel <-chan struct{}
+
trace *httptrace.ClientTrace // or nil
ID uint32
- resc chan http2resAndError
bufPipe http2pipe // buffered pipe with the flow-controlled response payload
- startedWrite bool // started request body write; guarded by cc.mu
requestedGzip bool
- on100 func() // optional code to run if get a 100 continue response
+ isHead bool
+
+ abortOnce sync.Once
+ abort chan struct{} // closed to signal stream should end immediately
+ abortErr error // set if abort is closed
+
+ peerClosed chan struct{} // closed when the peer sends an END_STREAM flag
+ donec chan struct{} // closed after the stream is in the closed state
+ on100 chan struct{} // buffered; written to if a 100 is received
+
+ respHeaderRecv chan struct{} // closed when headers are received
+ res *Response // set if respHeaderRecv is closed
flow http2flow // guarded by cc.mu
inflow http2flow // guarded by cc.mu
bytesRemain int64 // -1 means unknown; owned by transportResponseBody.Read
readErr error // sticky read error; owned by transportResponseBody.Read
- stopReqBody error // if non-nil, stop writing req body; guarded by cc.mu
- didReset bool // whether we sent a RST_STREAM to the server; guarded by cc.mu
- peerReset chan struct{} // closed on peer reset
- resetErr error // populated before peerReset is closed
+ reqBody io.ReadCloser
+ reqBodyContentLength int64 // -1 means unknown
+ reqBodyClosed bool // body has been closed; guarded by cc.mu
- done chan struct{} // closed when stream remove from cc.streams map; close calls guarded by cc.mu
+ // owned by writeRequest:
+ sentEndStream bool // sent an END_STREAM flag to the peer
+ sentHeaders bool
// owned by clientConnReadLoop:
firstByte bool // got the first response byte
pastHeaders bool // got first MetaHeadersFrame (actual headers)
pastTrailers bool // got optional second MetaHeadersFrame (trailers)
num1xx uint8 // number of 1xx responses seen
+ readClosed bool // peer sent an END_STREAM flag
+ readAborted bool // read loop reset the stream
trailer Header // accumulated trailers
resTrailer *Header // client's Response.Trailer
}
-// awaitRequestCancel waits for the user to cancel a request or for the done
-// channel to be signaled. A non-nil error is returned only if the request was
-// canceled.
-func http2awaitRequestCancel(req *Request, done <-chan struct{}) error {
- ctx := req.Context()
- if req.Cancel == nil && ctx.Done() == nil {
- return nil
- }
- select {
- case <-req.Cancel:
- return http2errRequestCanceled
- case <-ctx.Done():
- return ctx.Err()
- case <-done:
- return nil
- }
-}
-
var http2got1xxFuncForTests func(int, textproto.MIMEHeader) error
// get1xxTraceFunc returns the value of request's httptrace.ClientTrace.Got1xxResponse func,
@@ -6951,73 +7078,65 @@ func (cs *http2clientStream) get1xxTraceFunc() func(int, textproto.MIMEHeader) e
return http2traceGot1xxResponseFunc(cs.trace)
}
-// awaitRequestCancel waits for the user to cancel a request, its context to
-// expire, or for the request to be done (any way it might be removed from the
-// cc.streams map: peer reset, successful completion, TCP connection breakage,
-// etc). If the request is canceled, then cs will be canceled and closed.
-func (cs *http2clientStream) awaitRequestCancel(req *Request) {
- if err := http2awaitRequestCancel(req, cs.done); err != nil {
- cs.cancelStream()
- cs.bufPipe.CloseWithError(err)
- }
+func (cs *http2clientStream) abortStream(err error) {
+ cs.cc.mu.Lock()
+ defer cs.cc.mu.Unlock()
+ cs.abortStreamLocked(err)
}
-func (cs *http2clientStream) cancelStream() {
- cc := cs.cc
- cc.mu.Lock()
- didReset := cs.didReset
- cs.didReset = true
- cc.mu.Unlock()
-
- if !didReset {
- cc.writeStreamReset(cs.ID, http2ErrCodeCancel, nil)
- cc.forgetStreamID(cs.ID)
+func (cs *http2clientStream) abortStreamLocked(err error) {
+ cs.abortOnce.Do(func() {
+ cs.abortErr = err
+ close(cs.abort)
+ })
+ if cs.reqBody != nil && !cs.reqBodyClosed {
+ cs.reqBody.Close()
+ cs.reqBodyClosed = true
}
-}
-
-// checkResetOrDone reports any error sent in a RST_STREAM frame by the
-// server, or errStreamClosed if the stream is complete.
-func (cs *http2clientStream) checkResetOrDone() error {
- select {
- case <-cs.peerReset:
- return cs.resetErr
- case <-cs.done:
- return http2errStreamClosed
- default:
- return nil
+ // TODO(dneil): Clean up tests where cs.cc.cond is nil.
+ if cs.cc.cond != nil {
+ // Wake up writeRequestBody if it is waiting on flow control.
+ cs.cc.cond.Broadcast()
}
}
-func (cs *http2clientStream) getStartedWrite() bool {
+func (cs *http2clientStream) abortRequestBodyWrite() {
cc := cs.cc
cc.mu.Lock()
defer cc.mu.Unlock()
- return cs.startedWrite
-}
-
-func (cs *http2clientStream) abortRequestBodyWrite(err error) {
- if err == nil {
- panic("nil error")
+ if cs.reqBody != nil && !cs.reqBodyClosed {
+ cs.reqBody.Close()
+ cs.reqBodyClosed = true
+ cc.cond.Broadcast()
}
- cc := cs.cc
- cc.mu.Lock()
- cs.stopReqBody = err
- cc.cond.Broadcast()
- cc.mu.Unlock()
}
type http2stickyErrWriter struct {
- w io.Writer
- err *error
+ conn net.Conn
+ timeout time.Duration
+ err *error
}
func (sew http2stickyErrWriter) Write(p []byte) (n int, err error) {
if *sew.err != nil {
return 0, *sew.err
}
- n, err = sew.w.Write(p)
- *sew.err = err
- return
+ for {
+ if sew.timeout != 0 {
+ sew.conn.SetWriteDeadline(time.Now().Add(sew.timeout))
+ }
+ nn, err := sew.conn.Write(p[n:])
+ n += nn
+ if n < len(p) && nn > 0 && errors.Is(err, os.ErrDeadlineExceeded) {
+ // Keep extending the deadline so long as we're making progress.
+ continue
+ }
+ if sew.timeout != 0 {
+ sew.conn.SetWriteDeadline(time.Time{})
+ }
+ *sew.err = err
+ return n, err
+ }
}
// noCachedConnError is the concrete type of ErrNoCachedConn, which
@@ -7091,9 +7210,9 @@ func (t *http2Transport) RoundTripOpt(req *Request, opt http2RoundTripOpt) (*Res
}
reused := !atomic.CompareAndSwapUint32(&cc.reused, 0, 1)
http2traceGotConn(req, cc, reused)
- res, gotErrAfterReqBodyWrite, err := cc.roundTrip(req)
+ res, err := cc.RoundTrip(req)
if err != nil && retry <= 6 {
- if req, err = http2shouldRetryRequest(req, err, gotErrAfterReqBodyWrite); err == nil {
+ if req, err = http2shouldRetryRequest(req, err); err == nil {
// After the first retry, do exponential backoff with 10% jitter.
if retry == 0 {
continue
@@ -7104,7 +7223,7 @@ func (t *http2Transport) RoundTripOpt(req *Request, opt http2RoundTripOpt) (*Res
case <-time.After(time.Second * time.Duration(backoff)):
continue
case <-req.Context().Done():
- return nil, req.Context().Err()
+ err = req.Context().Err()
}
}
}
@@ -7135,7 +7254,7 @@ var (
// response headers. It is always called with a non-nil error.
// It returns either a request to retry (either the same request, or a
// modified clone), or an error if the request can't be replayed.
-func http2shouldRetryRequest(req *Request, err error, afterBodyWrite bool) (*Request, error) {
+func http2shouldRetryRequest(req *Request, err error) (*Request, error) {
if !http2canRetryError(err) {
return nil, err
}
@@ -7148,7 +7267,6 @@ func http2shouldRetryRequest(req *Request, err error, afterBodyWrite bool) (*Req
// If the request body can be reset back to its original
// state via the optional req.GetBody, do that.
if req.GetBody != nil {
- // TODO: consider a req.Body.Close here? or audit that all caller paths do?
body, err := req.GetBody()
if err != nil {
return nil, err
@@ -7160,10 +7278,8 @@ func http2shouldRetryRequest(req *Request, err error, afterBodyWrite bool) (*Req
// The Request.Body can't reset back to the beginning, but we
// don't seem to have started to read from it yet, so reuse
- // the request directly. The "afterBodyWrite" means the
- // bodyWrite process has started, which becomes true before
- // the first Read.
- if !afterBodyWrite {
+ // the request directly.
+ if err == http2errClientConnUnusable {
return req, nil
}
@@ -7175,6 +7291,10 @@ func http2canRetryError(err error) bool {
return true
}
if se, ok := err.(http2StreamError); ok {
+ if se.Code == http2ErrCodeProtocol && se.Cause == http2errFromPeer {
+ // See golang/go#47635, golang/go#42777
+ return true
+ }
return se.Code == http2ErrCodeRefusedStream
}
return false
@@ -7249,14 +7369,15 @@ func (t *http2Transport) newClientConn(c net.Conn, singleUse bool) (*http2Client
tconn: c,
readerDone: make(chan struct{}),
nextStreamID: 1,
- maxFrameSize: 16 << 10, // spec default
- initialWindowSize: 65535, // spec default
- maxConcurrentStreams: 1000, // "infinite", per spec. 1000 seems good enough.
- peerMaxHeaderListSize: 0xffffffffffffffff, // "infinite", per spec. Use 2^64-1 instead.
+ maxFrameSize: 16 << 10, // spec default
+ initialWindowSize: 65535, // spec default
+ maxConcurrentStreams: http2initialMaxConcurrentStreams, // "infinite", per spec. Use a smaller value until we have received server settings.
+ peerMaxHeaderListSize: 0xffffffffffffffff, // "infinite", per spec. Use 2^64-1 instead.
streams: make(map[uint32]*http2clientStream),
singleUse: singleUse,
wantSettingsAck: true,
pings: make(map[[8]byte]chan struct{}),
+ reqHeaderMu: make(chan struct{}, 1),
}
if d := t.idleConnTimeout(); d != 0 {
cc.idleTimeout = d
@@ -7271,9 +7392,16 @@ func (t *http2Transport) newClientConn(c net.Conn, singleUse bool) (*http2Client
// TODO: adjust this writer size to account for frame size +
// MTU + crypto/tls record padding.
- cc.bw = bufio.NewWriter(http2stickyErrWriter{c, &cc.werr})
+ cc.bw = bufio.NewWriter(http2stickyErrWriter{
+ conn: c,
+ timeout: t.WriteByteTimeout,
+ err: &cc.werr,
+ })
cc.br = bufio.NewReader(c)
cc.fr = http2NewFramer(cc.bw, cc.br)
+ if t.CountError != nil {
+ cc.fr.countError = t.CountError
+ }
cc.fr.ReadMetaHeaders = hpack.NewDecoder(http2initialHeaderTableSize, nil)
cc.fr.MaxHeaderListSize = t.maxHeaderListSize()
@@ -7326,6 +7454,13 @@ func (cc *http2ClientConn) healthCheck() {
}
}
+// SetDoNotReuse marks cc as not reusable for future HTTP requests.
+func (cc *http2ClientConn) SetDoNotReuse() {
+ cc.mu.Lock()
+ defer cc.mu.Unlock()
+ cc.doNotReuse = true
+}
+
func (cc *http2ClientConn) setGoAway(f *http2GoAwayFrame) {
cc.mu.Lock()
defer cc.mu.Unlock()
@@ -7343,27 +7478,94 @@ func (cc *http2ClientConn) setGoAway(f *http2GoAwayFrame) {
last := f.LastStreamID
for streamID, cs := range cc.streams {
if streamID > last {
- select {
- case cs.resc <- http2resAndError{err: http2errClientConnGotGoAway}:
- default:
- }
+ cs.abortStreamLocked(http2errClientConnGotGoAway)
}
}
}
// CanTakeNewRequest reports whether the connection can take a new request,
// meaning it has not been closed or received or sent a GOAWAY.
+//
+// If the caller is going to immediately make a new request on this
+// connection, use ReserveNewRequest instead.
func (cc *http2ClientConn) CanTakeNewRequest() bool {
cc.mu.Lock()
defer cc.mu.Unlock()
return cc.canTakeNewRequestLocked()
}
+// ReserveNewRequest is like CanTakeNewRequest but also reserves a
+// concurrent stream in cc. The reservation is decremented on the
+// next call to RoundTrip.
+func (cc *http2ClientConn) ReserveNewRequest() bool {
+ cc.mu.Lock()
+ defer cc.mu.Unlock()
+ if st := cc.idleStateLocked(); !st.canTakeNewRequest {
+ return false
+ }
+ cc.streamsReserved++
+ return true
+}
+
+// ClientConnState describes the state of a ClientConn.
+type http2ClientConnState struct {
+ // Closed is whether the connection is closed.
+ Closed bool
+
+ // Closing is whether the connection is in the process of
+ // closing. It may be closing due to shutdown, being a
+ // single-use connection, being marked as DoNotReuse, or
+ // having received a GOAWAY frame.
+ Closing bool
+
+ // StreamsActive is how many streams are active.
+ StreamsActive int
+
+ // StreamsReserved is how many streams have been reserved via
+ // ClientConn.ReserveNewRequest.
+ StreamsReserved int
+
+ // StreamsPending is how many requests have been sent in excess
+ // of the peer's advertised MaxConcurrentStreams setting and
+ // are waiting for other streams to complete.
+ StreamsPending int
+
+ // MaxConcurrentStreams is how many concurrent streams the
+ // peer advertised as acceptable. Zero means no SETTINGS
+ // frame has been received yet.
+ MaxConcurrentStreams uint32
+
+ // LastIdle, if non-zero, is when the connection last
+ // transitioned to idle state.
+ LastIdle time.Time
+}
+
+// State returns a snapshot of cc's state.
+func (cc *http2ClientConn) State() http2ClientConnState {
+ cc.wmu.Lock()
+ maxConcurrent := cc.maxConcurrentStreams
+ if !cc.seenSettings {
+ maxConcurrent = 0
+ }
+ cc.wmu.Unlock()
+
+ cc.mu.Lock()
+ defer cc.mu.Unlock()
+ return http2ClientConnState{
+ Closed: cc.closed,
+ Closing: cc.closing || cc.singleUse || cc.doNotReuse || cc.goAway != nil,
+ StreamsActive: len(cc.streams),
+ StreamsReserved: cc.streamsReserved,
+ StreamsPending: cc.pendingRequests,
+ LastIdle: cc.lastIdle,
+ MaxConcurrentStreams: maxConcurrent,
+ }
+}
+
// clientConnIdleState describes the suitability of a client
// connection to initiate a new RoundTrip request.
type http2clientConnIdleState struct {
canTakeNewRequest bool
- freshConn bool // whether it's unused by any previous request
}
func (cc *http2ClientConn) idleState() http2clientConnIdleState {
@@ -7384,13 +7586,13 @@ func (cc *http2ClientConn) idleStateLocked() (st http2clientConnIdleState) {
// writing it.
maxConcurrentOkay = true
} else {
- maxConcurrentOkay = int64(len(cc.streams)+1) < int64(cc.maxConcurrentStreams)
+ maxConcurrentOkay = int64(len(cc.streams)+cc.streamsReserved+1) <= int64(cc.maxConcurrentStreams)
}
st.canTakeNewRequest = cc.goAway == nil && !cc.closed && !cc.closing && maxConcurrentOkay &&
+ !cc.doNotReuse &&
int64(cc.nextStreamID)+2*int64(cc.pendingRequests) < math.MaxInt32 &&
!cc.tooIdleLocked()
- st.freshConn = cc.nextStreamID == 1 && st.canTakeNewRequest
return
}
@@ -7421,7 +7623,7 @@ func (cc *http2ClientConn) onIdleTimeout() {
func (cc *http2ClientConn) closeIfIdle() {
cc.mu.Lock()
- if len(cc.streams) > 0 {
+ if len(cc.streams) > 0 || cc.streamsReserved > 0 {
cc.mu.Unlock()
return
}
@@ -7436,9 +7638,15 @@ func (cc *http2ClientConn) closeIfIdle() {
cc.tconn.Close()
}
+func (cc *http2ClientConn) isDoNotReuseAndIdle() bool {
+ cc.mu.Lock()
+ defer cc.mu.Unlock()
+ return cc.doNotReuse && len(cc.streams) == 0
+}
+
var http2shutdownEnterWaitStateHook = func() {}
-// Shutdown gracefully close the client connection, waiting for running streams to complete.
+// Shutdown gracefully closes the client connection, waiting for running streams to complete.
func (cc *http2ClientConn) Shutdown(ctx context.Context) error {
if err := cc.sendGoAway(); err != nil {
return err
@@ -7477,15 +7685,18 @@ func (cc *http2ClientConn) Shutdown(ctx context.Context) error {
func (cc *http2ClientConn) sendGoAway() error {
cc.mu.Lock()
- defer cc.mu.Unlock()
- cc.wmu.Lock()
- defer cc.wmu.Unlock()
- if cc.closing {
+ closing := cc.closing
+ cc.closing = true
+ maxStreamID := cc.nextStreamID
+ cc.mu.Unlock()
+ if closing {
// GOAWAY sent already
return nil
}
+
+ cc.wmu.Lock()
+ defer cc.wmu.Unlock()
// Send a graceful shutdown frame to server
- maxStreamID := cc.nextStreamID
if err := cc.fr.WriteGoAway(maxStreamID, http2ErrCodeNo, nil); err != nil {
return err
}
@@ -7493,7 +7704,6 @@ func (cc *http2ClientConn) sendGoAway() error {
return err
}
// Prevent new requests
- cc.closing = true
return nil
}
@@ -7501,17 +7711,12 @@ func (cc *http2ClientConn) sendGoAway() error {
// err is sent to streams.
func (cc *http2ClientConn) closeForError(err error) error {
cc.mu.Lock()
+ cc.closed = true
+ for _, cs := range cc.streams {
+ cs.abortStreamLocked(err)
+ }
defer cc.cond.Broadcast()
defer cc.mu.Unlock()
- for id, cs := range cc.streams {
- select {
- case cs.resc <- http2resAndError{err: err}:
- default:
- }
- cs.bufPipe.CloseWithError(err)
- delete(cc.streams, id)
- }
- cc.closed = true
return cc.tconn.Close()
}
@@ -7526,47 +7731,10 @@ func (cc *http2ClientConn) Close() error {
// closes the client connection immediately. In-flight requests are interrupted.
func (cc *http2ClientConn) closeForLostPing() error {
err := errors.New("http2: client connection lost")
- return cc.closeForError(err)
-}
-
-const http2maxAllocFrameSize = 512 << 10
-
-// frameBuffer returns a scratch buffer suitable for writing DATA frames.
-// They're capped at the min of the peer's max frame size or 512KB
-// (kinda arbitrarily), but definitely capped so we don't allocate 4GB
-// bufers.
-func (cc *http2ClientConn) frameScratchBuffer() []byte {
- cc.mu.Lock()
- size := cc.maxFrameSize
- if size > http2maxAllocFrameSize {
- size = http2maxAllocFrameSize
- }
- for i, buf := range cc.freeBuf {
- if len(buf) >= int(size) {
- cc.freeBuf[i] = nil
- cc.mu.Unlock()
- return buf[:size]
- }
- }
- cc.mu.Unlock()
- return make([]byte, size)
-}
-
-func (cc *http2ClientConn) putFrameScratchBuffer(buf []byte) {
- cc.mu.Lock()
- defer cc.mu.Unlock()
- const maxBufs = 4 // arbitrary; 4 concurrent requests per conn? investigate.
- if len(cc.freeBuf) < maxBufs {
- cc.freeBuf = append(cc.freeBuf, buf)
- return
- }
- for i, old := range cc.freeBuf {
- if old == nil {
- cc.freeBuf[i] = buf
- return
- }
+ if f := cc.t.CountError; f != nil {
+ f("conn_close_lost_ping")
}
- // forget about it.
+ return cc.closeForError(err)
}
// errRequestCanceled is a copy of net/http's errRequestCanceled because it's not
@@ -7630,41 +7798,158 @@ func http2actualContentLength(req *Request) int64 {
return -1
}
+func (cc *http2ClientConn) decrStreamReservations() {
+ cc.mu.Lock()
+ defer cc.mu.Unlock()
+ cc.decrStreamReservationsLocked()
+}
+
+func (cc *http2ClientConn) decrStreamReservationsLocked() {
+ if cc.streamsReserved > 0 {
+ cc.streamsReserved--
+ }
+}
+
func (cc *http2ClientConn) RoundTrip(req *Request) (*Response, error) {
- resp, _, err := cc.roundTrip(req)
- return resp, err
+ ctx := req.Context()
+ cs := &http2clientStream{
+ cc: cc,
+ ctx: ctx,
+ reqCancel: req.Cancel,
+ isHead: req.Method == "HEAD",
+ reqBody: req.Body,
+ reqBodyContentLength: http2actualContentLength(req),
+ trace: httptrace.ContextClientTrace(ctx),
+ peerClosed: make(chan struct{}),
+ abort: make(chan struct{}),
+ respHeaderRecv: make(chan struct{}),
+ donec: make(chan struct{}),
+ }
+ go cs.doRequest(req)
+
+ waitDone := func() error {
+ select {
+ case <-cs.donec:
+ return nil
+ case <-ctx.Done():
+ return ctx.Err()
+ case <-cs.reqCancel:
+ return http2errRequestCanceled
+ }
+ }
+
+ handleResponseHeaders := func() (*Response, error) {
+ res := cs.res
+ if res.StatusCode > 299 {
+ // On error or status code 3xx, 4xx, 5xx, etc abort any
+ // ongoing write, assuming that the server doesn't care
+ // about our request body. If the server replied with 1xx or
+ // 2xx, however, then assume the server DOES potentially
+ // want our body (e.g. full-duplex streaming:
+ // golang.org/issue/13444). If it turns out the server
+ // doesn't, they'll RST_STREAM us soon enough. This is a
+ // heuristic to avoid adding knobs to Transport. Hopefully
+ // we can keep it.
+ cs.abortRequestBodyWrite()
+ }
+ res.Request = req
+ res.TLS = cc.tlsState
+ if res.Body == http2noBody && http2actualContentLength(req) == 0 {
+ // If there isn't a request or response body still being
+ // written, then wait for the stream to be closed before
+ // RoundTrip returns.
+ if err := waitDone(); err != nil {
+ return nil, err
+ }
+ }
+ return res, nil
+ }
+
+ for {
+ select {
+ case <-cs.respHeaderRecv:
+ return handleResponseHeaders()
+ case <-cs.abort:
+ select {
+ case <-cs.respHeaderRecv:
+ // If both cs.respHeaderRecv and cs.abort are signaling,
+ // pick respHeaderRecv. The server probably wrote the
+ // response and immediately reset the stream.
+ // golang.org/issue/49645
+ return handleResponseHeaders()
+ default:
+ waitDone()
+ return nil, cs.abortErr
+ }
+ case <-ctx.Done():
+ err := ctx.Err()
+ cs.abortStream(err)
+ return nil, err
+ case <-cs.reqCancel:
+ cs.abortStream(http2errRequestCanceled)
+ return nil, http2errRequestCanceled
+ }
+ }
}
-func (cc *http2ClientConn) roundTrip(req *Request) (res *Response, gotErrAfterReqBodyWrite bool, err error) {
+// doRequest runs for the duration of the request lifetime.
+//
+// It sends the request and performs post-request cleanup (closing Request.Body, etc.).
+func (cs *http2clientStream) doRequest(req *Request) {
+ err := cs.writeRequest(req)
+ cs.cleanupWriteRequest(err)
+}
+
+// writeRequest sends a request.
+//
+// It returns nil after the request is written, the response read,
+// and the request stream is half-closed by the peer.
+//
+// It returns non-nil if the request ends otherwise.
+// If the returned error is StreamError, the error Code may be used in resetting the stream.
+func (cs *http2clientStream) writeRequest(req *Request) (err error) {
+ cc := cs.cc
+ ctx := cs.ctx
+
if err := http2checkConnHeaders(req); err != nil {
- return nil, false, err
- }
- if cc.idleTimer != nil {
- cc.idleTimer.Stop()
+ return err
}
- trailers, err := http2commaSeparatedTrailers(req)
- if err != nil {
- return nil, false, err
+ // Acquire the new-request lock by writing to reqHeaderMu.
+ // This lock guards the critical section covering allocating a new stream ID
+ // (requires mu) and creating the stream (requires wmu).
+ if cc.reqHeaderMu == nil {
+ panic("RoundTrip on uninitialized ClientConn") // for tests
+ }
+ select {
+ case cc.reqHeaderMu <- struct{}{}:
+ case <-cs.reqCancel:
+ return http2errRequestCanceled
+ case <-ctx.Done():
+ return ctx.Err()
}
- hasTrailers := trailers != ""
cc.mu.Lock()
- if err := cc.awaitOpenSlotForRequest(req); err != nil {
+ if cc.idleTimer != nil {
+ cc.idleTimer.Stop()
+ }
+ cc.decrStreamReservationsLocked()
+ if err := cc.awaitOpenSlotForStreamLocked(cs); err != nil {
cc.mu.Unlock()
- return nil, false, err
+ <-cc.reqHeaderMu
+ return err
}
-
- body := req.Body
- contentLen := http2actualContentLength(req)
- hasBody := contentLen != 0
+ cc.addStreamLocked(cs) // assigns stream ID
+ if http2isConnectionCloseRequest(req) {
+ cc.doNotReuse = true
+ }
+ cc.mu.Unlock()
// TODO(bradfitz): this is a copy of the logic in net/http. Unify somewhere?
- var requestedGzip bool
if !cc.t.disableCompression() &&
req.Header.Get("Accept-Encoding") == "" &&
req.Header.Get("Range") == "" &&
- req.Method != "HEAD" {
+ !cs.isHead {
// Request gzip only, not deflate. Deflate is ambiguous and
// not as universally supported anyway.
// See: https://zlib.net/zlib_faq.html#faq39
@@ -7677,210 +7962,224 @@ func (cc *http2ClientConn) roundTrip(req *Request) (res *Response, gotErrAfterRe
// We don't request gzip if the request is for a range, since
// auto-decoding a portion of a gzipped document will just fail
// anyway. See https://golang.org/issue/8923
- requestedGzip = true
+ cs.requestedGzip = true
}
- // we send: HEADERS{1}, CONTINUATION{0,} + DATA{0,} (DATA is
- // sent by writeRequestBody below, along with any Trailers,
- // again in form HEADERS{1}, CONTINUATION{0,})
- hdrs, err := cc.encodeHeaders(req, requestedGzip, trailers, contentLen)
- if err != nil {
- cc.mu.Unlock()
- return nil, false, err
+ continueTimeout := cc.t.expectContinueTimeout()
+ if continueTimeout != 0 {
+ if !httpguts.HeaderValuesContainsToken(req.Header["Expect"], "100-continue") {
+ continueTimeout = 0
+ } else {
+ cs.on100 = make(chan struct{}, 1)
+ }
}
- cs := cc.newStream()
- cs.req = req
- cs.trace = httptrace.ContextClientTrace(req.Context())
- cs.requestedGzip = requestedGzip
- bodyWriter := cc.t.getBodyWriterState(cs, body)
- cs.on100 = bodyWriter.on100
+ // Past this point (where we send request headers), it is possible for
+ // RoundTrip to return successfully. Since the RoundTrip contract permits
+ // the caller to "mutate or reuse" the Request after closing the Response's Body,
+ // we must take care when referencing the Request from here on.
+ err = cs.encodeAndWriteHeaders(req)
+ <-cc.reqHeaderMu
+ if err != nil {
+ return err
+ }
- defer func() {
- cc.wmu.Lock()
- werr := cc.werr
- cc.wmu.Unlock()
- if werr != nil {
- cc.Close()
+ hasBody := cs.reqBodyContentLength != 0
+ if !hasBody {
+ cs.sentEndStream = true
+ } else {
+ if continueTimeout != 0 {
+ http2traceWait100Continue(cs.trace)
+ timer := time.NewTimer(continueTimeout)
+ select {
+ case <-timer.C:
+ err = nil
+ case <-cs.on100:
+ err = nil
+ case <-cs.abort:
+ err = cs.abortErr
+ case <-ctx.Done():
+ err = ctx.Err()
+ case <-cs.reqCancel:
+ err = http2errRequestCanceled
+ }
+ timer.Stop()
+ if err != nil {
+ http2traceWroteRequest(cs.trace, err)
+ return err
+ }
}
- }()
-
- cc.wmu.Lock()
- endStream := !hasBody && !hasTrailers
- werr := cc.writeHeaders(cs.ID, endStream, int(cc.maxFrameSize), hdrs)
- cc.wmu.Unlock()
- http2traceWroteHeaders(cs.trace)
- cc.mu.Unlock()
- if werr != nil {
- if hasBody {
- req.Body.Close() // per RoundTripper contract
- bodyWriter.cancel()
+ if err = cs.writeRequestBody(req); err != nil {
+ if err != http2errStopReqBodyWrite {
+ http2traceWroteRequest(cs.trace, err)
+ return err
+ }
+ } else {
+ cs.sentEndStream = true
}
- cc.forgetStreamID(cs.ID)
- // Don't bother sending a RST_STREAM (our write already failed;
- // no need to keep writing)
- http2traceWroteRequest(cs.trace, werr)
- return nil, false, werr
}
+ http2traceWroteRequest(cs.trace, err)
+
var respHeaderTimer <-chan time.Time
- if hasBody {
- bodyWriter.scheduleBodyWrite()
- } else {
- http2traceWroteRequest(cs.trace, nil)
- if d := cc.responseHeaderTimeout(); d != 0 {
- timer := time.NewTimer(d)
- defer timer.Stop()
- respHeaderTimer = timer.C
+ var respHeaderRecv chan struct{}
+ if d := cc.responseHeaderTimeout(); d != 0 {
+ timer := time.NewTimer(d)
+ defer timer.Stop()
+ respHeaderTimer = timer.C
+ respHeaderRecv = cs.respHeaderRecv
+ }
+ // Wait until the peer half-closes its end of the stream,
+ // or until the request is aborted (via context, error, or otherwise),
+ // whichever comes first.
+ for {
+ select {
+ case <-cs.peerClosed:
+ return nil
+ case <-respHeaderTimer:
+ return http2errTimeout
+ case <-respHeaderRecv:
+ respHeaderRecv = nil
+ respHeaderTimer = nil // keep waiting for END_STREAM
+ case <-cs.abort:
+ return cs.abortErr
+ case <-ctx.Done():
+ return ctx.Err()
+ case <-cs.reqCancel:
+ return http2errRequestCanceled
}
}
+}
- readLoopResCh := cs.resc
- bodyWritten := false
- ctx := req.Context()
+func (cs *http2clientStream) encodeAndWriteHeaders(req *Request) error {
+ cc := cs.cc
+ ctx := cs.ctx
- handleReadLoopResponse := func(re http2resAndError) (*Response, bool, error) {
- res := re.res
- if re.err != nil || res.StatusCode > 299 {
- // On error or status code 3xx, 4xx, 5xx, etc abort any
- // ongoing write, assuming that the server doesn't care
- // about our request body. If the server replied with 1xx or
- // 2xx, however, then assume the server DOES potentially
- // want our body (e.g. full-duplex streaming:
- // golang.org/issue/13444). If it turns out the server
- // doesn't, they'll RST_STREAM us soon enough. This is a
- // heuristic to avoid adding knobs to Transport. Hopefully
- // we can keep it.
- bodyWriter.cancel()
- cs.abortRequestBodyWrite(http2errStopReqBodyWrite)
- if hasBody && !bodyWritten {
- <-bodyWriter.resc
- }
- }
- if re.err != nil {
- cc.forgetStreamID(cs.ID)
- return nil, cs.getStartedWrite(), re.err
- }
- res.Request = req
- res.TLS = cc.tlsState
- return res, false, nil
+ cc.wmu.Lock()
+ defer cc.wmu.Unlock()
+
+ // If the request was canceled while waiting for cc.mu, just quit.
+ select {
+ case <-cs.abort:
+ return cs.abortErr
+ case <-ctx.Done():
+ return ctx.Err()
+ case <-cs.reqCancel:
+ return http2errRequestCanceled
+ default:
}
- for {
+ // Encode headers.
+ //
+ // we send: HEADERS{1}, CONTINUATION{0,} + DATA{0,} (DATA is
+ // sent by writeRequestBody below, along with any Trailers,
+ // again in form HEADERS{1}, CONTINUATION{0,})
+ trailers, err := http2commaSeparatedTrailers(req)
+ if err != nil {
+ return err
+ }
+ hasTrailers := trailers != ""
+ contentLen := http2actualContentLength(req)
+ hasBody := contentLen != 0
+ hdrs, err := cc.encodeHeaders(req, cs.requestedGzip, trailers, contentLen)
+ if err != nil {
+ return err
+ }
+
+ // Write the request.
+ endStream := !hasBody && !hasTrailers
+ cs.sentHeaders = true
+ err = cc.writeHeaders(cs.ID, endStream, int(cc.maxFrameSize), hdrs)
+ http2traceWroteHeaders(cs.trace)
+ return err
+}
+
+// cleanupWriteRequest performs post-request tasks.
+//
+// If err (the result of writeRequest) is non-nil and the stream is not closed,
+// cleanupWriteRequest will send a reset to the peer.
+func (cs *http2clientStream) cleanupWriteRequest(err error) {
+ cc := cs.cc
+
+ if cs.ID == 0 {
+ // We were canceled before creating the stream, so return our reservation.
+ cc.decrStreamReservations()
+ }
+
+ // TODO: write h12Compare test showing whether
+ // Request.Body is closed by the Transport,
+ // and in multiple cases: server replies <=299 and >299
+ // while still writing request body
+ cc.mu.Lock()
+ bodyClosed := cs.reqBodyClosed
+ cs.reqBodyClosed = true
+ cc.mu.Unlock()
+ if !bodyClosed && cs.reqBody != nil {
+ cs.reqBody.Close()
+ }
+
+ if err != nil && cs.sentEndStream {
+ // If the connection is closed immediately after the response is read,
+ // we may be aborted before finishing up here. If the stream was closed
+ // cleanly on both sides, there is no error.
select {
- case re := <-readLoopResCh:
- return handleReadLoopResponse(re)
- case <-respHeaderTimer:
- if !hasBody || bodyWritten {
- cc.writeStreamReset(cs.ID, http2ErrCodeCancel, nil)
- } else {
- bodyWriter.cancel()
- cs.abortRequestBodyWrite(http2errStopReqBodyWriteAndCancel)
- <-bodyWriter.resc
- }
- cc.forgetStreamID(cs.ID)
- return nil, cs.getStartedWrite(), http2errTimeout
- case <-ctx.Done():
- select {
- case re := <-readLoopResCh:
- return handleReadLoopResponse(re)
- default:
- }
- if !hasBody || bodyWritten {
- cc.writeStreamReset(cs.ID, http2ErrCodeCancel, nil)
- } else {
- bodyWriter.cancel()
- cs.abortRequestBodyWrite(http2errStopReqBodyWriteAndCancel)
- <-bodyWriter.resc
- }
- cc.forgetStreamID(cs.ID)
- return nil, cs.getStartedWrite(), ctx.Err()
- case <-req.Cancel:
- select {
- case re := <-readLoopResCh:
- return handleReadLoopResponse(re)
- default:
- }
- if !hasBody || bodyWritten {
- cc.writeStreamReset(cs.ID, http2ErrCodeCancel, nil)
+ case <-cs.peerClosed:
+ err = nil
+ default:
+ }
+ }
+ if err != nil {
+ cs.abortStream(err) // possibly redundant, but harmless
+ if cs.sentHeaders {
+ if se, ok := err.(http2StreamError); ok {
+ if se.Cause != http2errFromPeer {
+ cc.writeStreamReset(cs.ID, se.Code, err)
+ }
} else {
- bodyWriter.cancel()
- cs.abortRequestBodyWrite(http2errStopReqBodyWriteAndCancel)
- <-bodyWriter.resc
- }
- cc.forgetStreamID(cs.ID)
- return nil, cs.getStartedWrite(), http2errRequestCanceled
- case <-cs.peerReset:
- select {
- case re := <-readLoopResCh:
- return handleReadLoopResponse(re)
- default:
- }
- // processResetStream already removed the
- // stream from the streams map; no need for
- // forgetStreamID.
- return nil, cs.getStartedWrite(), cs.resetErr
- case err := <-bodyWriter.resc:
- bodyWritten = true
- // Prefer the read loop's response, if available. Issue 16102.
- select {
- case re := <-readLoopResCh:
- return handleReadLoopResponse(re)
- default:
- }
- if err != nil {
- cc.forgetStreamID(cs.ID)
- return nil, cs.getStartedWrite(), err
- }
- if d := cc.responseHeaderTimeout(); d != 0 {
- timer := time.NewTimer(d)
- defer timer.Stop()
- respHeaderTimer = timer.C
+ cc.writeStreamReset(cs.ID, http2ErrCodeCancel, err)
}
}
+ cs.bufPipe.CloseWithError(err) // no-op if already closed
+ } else {
+ if cs.sentHeaders && !cs.sentEndStream {
+ cc.writeStreamReset(cs.ID, http2ErrCodeNo, nil)
+ }
+ cs.bufPipe.CloseWithError(http2errRequestCanceled)
+ }
+ if cs.ID != 0 {
+ cc.forgetStreamID(cs.ID)
+ }
+
+ cc.wmu.Lock()
+ werr := cc.werr
+ cc.wmu.Unlock()
+ if werr != nil {
+ cc.Close()
}
+
+ close(cs.donec)
}
-// awaitOpenSlotForRequest waits until len(streams) < maxConcurrentStreams.
+// awaitOpenSlotForStream waits until len(streams) < maxConcurrentStreams.
// Must hold cc.mu.
-func (cc *http2ClientConn) awaitOpenSlotForRequest(req *Request) error {
- var waitingForConn chan struct{}
- var waitingForConnErr error // guarded by cc.mu
+func (cc *http2ClientConn) awaitOpenSlotForStreamLocked(cs *http2clientStream) error {
for {
cc.lastActive = time.Now()
if cc.closed || !cc.canTakeNewRequestLocked() {
- if waitingForConn != nil {
- close(waitingForConn)
- }
return http2errClientConnUnusable
}
cc.lastIdle = time.Time{}
- if int64(len(cc.streams))+1 <= int64(cc.maxConcurrentStreams) {
- if waitingForConn != nil {
- close(waitingForConn)
- }
+ if int64(len(cc.streams)) < int64(cc.maxConcurrentStreams) {
return nil
}
- // Unfortunately, we cannot wait on a condition variable and channel at
- // the same time, so instead, we spin up a goroutine to check if the
- // request is canceled while we wait for a slot to open in the connection.
- if waitingForConn == nil {
- waitingForConn = make(chan struct{})
- go func() {
- if err := http2awaitRequestCancel(req, waitingForConn); err != nil {
- cc.mu.Lock()
- waitingForConnErr = err
- cc.cond.Broadcast()
- cc.mu.Unlock()
- }
- }()
- }
cc.pendingRequests++
cc.cond.Wait()
cc.pendingRequests--
- if waitingForConnErr != nil {
- return waitingForConnErr
+ select {
+ case <-cs.abort:
+ return cs.abortErr
+ default:
}
}
}
@@ -7907,10 +8206,6 @@ func (cc *http2ClientConn) writeHeaders(streamID uint32, endStream bool, maxFram
cc.fr.WriteContinuation(streamID, endHeaders, chunk)
}
}
- // TODO(bradfitz): this Flush could potentially block (as
- // could the WriteHeaders call(s) above), which means they
- // wouldn't respond to Request.Cancel being readable. That's
- // rare, but this should probably be in a goroutine.
cc.bw.Flush()
return cc.werr
}
@@ -7926,32 +8221,59 @@ var (
http2errReqBodyTooLong = errors.New("http2: request body larger than specified content length")
)
-func (cs *http2clientStream) writeRequestBody(body io.Reader, bodyCloser io.Closer) (err error) {
+// frameScratchBufferLen returns the length of a buffer to use for
+// outgoing request bodies to read/write to/from.
+//
+// It returns max(1, min(peer's advertised max frame size,
+// Request.ContentLength+1, 512KB)).
+func (cs *http2clientStream) frameScratchBufferLen(maxFrameSize int) int {
+ const max = 512 << 10
+ n := int64(maxFrameSize)
+ if n > max {
+ n = max
+ }
+ if cl := cs.reqBodyContentLength; cl != -1 && cl+1 < n {
+ // Add an extra byte past the declared content-length to
+ // give the caller's Request.Body io.Reader a chance to
+ // give us more bytes than they declared, so we can catch it
+ // early.
+ n = cl + 1
+ }
+ if n < 1 {
+ return 1
+ }
+ return int(n) // doesn't truncate; max is 512K
+}
+
+var http2bufPool sync.Pool // of *[]byte
+
+func (cs *http2clientStream) writeRequestBody(req *Request) (err error) {
cc := cs.cc
+ body := cs.reqBody
sentEnd := false // whether we sent the final DATA frame w/ END_STREAM
- buf := cc.frameScratchBuffer()
- defer cc.putFrameScratchBuffer(buf)
-
- defer func() {
- http2traceWroteRequest(cs.trace, err)
- // TODO: write h12Compare test showing whether
- // Request.Body is closed by the Transport,
- // and in multiple cases: server replies <=299 and >299
- // while still writing request body
- cerr := bodyCloser.Close()
- if err == nil {
- err = cerr
- }
- }()
- req := cs.req
hasTrailers := req.Trailer != nil
- remainLen := http2actualContentLength(req)
+ remainLen := cs.reqBodyContentLength
hasContentLen := remainLen != -1
+ cc.mu.Lock()
+ maxFrameSize := int(cc.maxFrameSize)
+ cc.mu.Unlock()
+
+ // Scratch buffer for reading into & writing from.
+ scratchLen := cs.frameScratchBufferLen(maxFrameSize)
+ var buf []byte
+ if bp, ok := http2bufPool.Get().(*[]byte); ok && len(*bp) >= scratchLen {
+ defer http2bufPool.Put(bp)
+ buf = *bp
+ } else {
+ buf = make([]byte, scratchLen)
+ defer http2bufPool.Put(&buf)
+ }
+
var sawEOF bool
for !sawEOF {
- n, err := body.Read(buf[:len(buf)-1])
+ n, err := body.Read(buf[:len(buf)])
if hasContentLen {
remainLen -= int64(n)
if remainLen == 0 && err == nil {
@@ -7962,35 +8284,36 @@ func (cs *http2clientStream) writeRequestBody(body io.Reader, bodyCloser io.Clos
// to send the END_STREAM bit early, double-check that we're actually
// at EOF. Subsequent reads should return (0, EOF) at this point.
// If either value is different, we return an error in one of two ways below.
+ var scratch [1]byte
var n1 int
- n1, err = body.Read(buf[n:])
+ n1, err = body.Read(scratch[:])
remainLen -= int64(n1)
}
if remainLen < 0 {
err = http2errReqBodyTooLong
- cc.writeStreamReset(cs.ID, http2ErrCodeCancel, err)
return err
}
}
- if err == io.EOF {
- sawEOF = true
- err = nil
- } else if err != nil {
- cc.writeStreamReset(cs.ID, http2ErrCodeCancel, err)
- return err
+ if err != nil {
+ cc.mu.Lock()
+ bodyClosed := cs.reqBodyClosed
+ cc.mu.Unlock()
+ switch {
+ case bodyClosed:
+ return http2errStopReqBodyWrite
+ case err == io.EOF:
+ sawEOF = true
+ err = nil
+ default:
+ return err
+ }
}
remain := buf[:n]
for len(remain) > 0 && err == nil {
var allowed int32
allowed, err = cs.awaitFlowControl(len(remain))
- switch {
- case err == http2errStopReqBodyWrite:
- return err
- case err == http2errStopReqBodyWriteAndCancel:
- cc.writeStreamReset(cs.ID, http2ErrCodeCancel, nil)
- return err
- case err != nil:
+ if err != nil {
return err
}
cc.wmu.Lock()
@@ -8021,24 +8344,26 @@ func (cs *http2clientStream) writeRequestBody(body io.Reader, bodyCloser io.Clos
return nil
}
- var trls []byte
- if hasTrailers {
- cc.mu.Lock()
- trls, err = cc.encodeTrailers(req)
- cc.mu.Unlock()
- if err != nil {
- cc.writeStreamReset(cs.ID, http2ErrCodeInternal, err)
- cc.forgetStreamID(cs.ID)
- return err
- }
- }
-
+ // Since the RoundTrip contract permits the caller to "mutate or reuse"
+ // a request after the Response's Body is closed, verify that this hasn't
+ // happened before accessing the trailers.
cc.mu.Lock()
- maxFrameSize := int(cc.maxFrameSize)
+ trailer := req.Trailer
+ err = cs.abortErr
cc.mu.Unlock()
+ if err != nil {
+ return err
+ }
cc.wmu.Lock()
defer cc.wmu.Unlock()
+ var trls []byte
+ if len(trailer) > 0 {
+ trls, err = cc.encodeTrailers(trailer)
+ if err != nil {
+ return err
+ }
+ }
// Two ways to send END_STREAM: either with trailers, or
// with an empty DATA frame.
@@ -8059,17 +8384,24 @@ func (cs *http2clientStream) writeRequestBody(body io.Reader, bodyCloser io.Clos
// if the stream is dead.
func (cs *http2clientStream) awaitFlowControl(maxBytes int) (taken int32, err error) {
cc := cs.cc
+ ctx := cs.ctx
cc.mu.Lock()
defer cc.mu.Unlock()
for {
if cc.closed {
return 0, http2errClientConnClosed
}
- if cs.stopReqBody != nil {
- return 0, cs.stopReqBody
+ if cs.reqBodyClosed {
+ return 0, http2errStopReqBodyWrite
}
- if err := cs.checkResetOrDone(); err != nil {
- return 0, err
+ select {
+ case <-cs.abort:
+ return 0, cs.abortErr
+ case <-ctx.Done():
+ return 0, ctx.Err()
+ case <-cs.reqCancel:
+ return 0, http2errRequestCanceled
+ default:
}
if a := cs.flow.available(); a > 0 {
take := a
@@ -8087,9 +8419,14 @@ func (cs *http2clientStream) awaitFlowControl(maxBytes int) (taken int32, err er
}
}
-// requires cc.mu be held.
+var http2errNilRequestURL = errors.New("http2: Request.URI is nil")
+
+// requires cc.wmu be held.
func (cc *http2ClientConn) encodeHeaders(req *Request, addGzipHeader bool, trailers string, contentLength int64) ([]byte, error) {
cc.hbuf.Reset()
+ if req.URL == nil {
+ return nil, http2errNilRequestURL
+ }
host := req.Host
if host == "" {
@@ -8275,12 +8612,12 @@ func http2shouldSendReqContentLength(method string, contentLength int64) bool {
}
}
-// requires cc.mu be held.
-func (cc *http2ClientConn) encodeTrailers(req *Request) ([]byte, error) {
+// requires cc.wmu be held.
+func (cc *http2ClientConn) encodeTrailers(trailer Header) ([]byte, error) {
cc.hbuf.Reset()
hlSize := uint64(0)
- for k, vv := range req.Trailer {
+ for k, vv := range trailer {
for _, v := range vv {
hf := hpack.HeaderField{Name: k, Value: v}
hlSize += uint64(hf.Size())
@@ -8290,7 +8627,7 @@ func (cc *http2ClientConn) encodeTrailers(req *Request) ([]byte, error) {
return nil, http2errRequestHeaderListSize
}
- for k, vv := range req.Trailer {
+ for k, vv := range trailer {
lowKey, ascii := http2asciiToLower(k)
if !ascii {
// Skip writing invalid headers. Per RFC 7540, Section 8.1.2, header
@@ -8320,51 +8657,51 @@ type http2resAndError struct {
}
// requires cc.mu be held.
-func (cc *http2ClientConn) newStream() *http2clientStream {
- cs := &http2clientStream{
- cc: cc,
- ID: cc.nextStreamID,
- resc: make(chan http2resAndError, 1),
- peerReset: make(chan struct{}),
- done: make(chan struct{}),
- }
+func (cc *http2ClientConn) addStreamLocked(cs *http2clientStream) {
cs.flow.add(int32(cc.initialWindowSize))
cs.flow.setConnFlow(&cc.flow)
cs.inflow.add(http2transportDefaultStreamFlow)
cs.inflow.setConnFlow(&cc.inflow)
+ cs.ID = cc.nextStreamID
cc.nextStreamID += 2
cc.streams[cs.ID] = cs
- return cs
+ if cs.ID == 0 {
+ panic("assigned stream ID 0")
+ }
}
func (cc *http2ClientConn) forgetStreamID(id uint32) {
- cc.streamByID(id, true)
-}
-
-func (cc *http2ClientConn) streamByID(id uint32, andRemove bool) *http2clientStream {
cc.mu.Lock()
- defer cc.mu.Unlock()
- cs := cc.streams[id]
- if andRemove && cs != nil && !cc.closed {
- cc.lastActive = time.Now()
- delete(cc.streams, id)
- if len(cc.streams) == 0 && cc.idleTimer != nil {
- cc.idleTimer.Reset(cc.idleTimeout)
- cc.lastIdle = time.Now()
- }
- close(cs.done)
- // Wake up checkResetOrDone via clientStream.awaitFlowControl and
- // wake up RoundTrip if there is a pending request.
- cc.cond.Broadcast()
+ slen := len(cc.streams)
+ delete(cc.streams, id)
+ if len(cc.streams) != slen-1 {
+ panic("forgetting unknown stream id")
+ }
+ cc.lastActive = time.Now()
+ if len(cc.streams) == 0 && cc.idleTimer != nil {
+ cc.idleTimer.Reset(cc.idleTimeout)
+ cc.lastIdle = time.Now()
+ }
+ // Wake up writeRequestBody via clientStream.awaitFlowControl and
+ // wake up RoundTrip if there is a pending request.
+ cc.cond.Broadcast()
+
+ closeOnIdle := cc.singleUse || cc.doNotReuse || cc.t.disableKeepAlives()
+ if closeOnIdle && cc.streamsReserved == 0 && len(cc.streams) == 0 {
+ if http2VerboseLogs {
+ cc.vlogf("http2: Transport closing idle conn %p (forSingleUse=%v, maxStream=%v)", cc, cc.singleUse, cc.nextStreamID-2)
+ }
+ cc.closed = true
+ defer cc.tconn.Close()
}
- return cs
+
+ cc.mu.Unlock()
}
// clientConnReadLoop is the state owned by the clientConn's frame-reading readLoop.
type http2clientConnReadLoop struct {
- _ http2incomparable
- cc *http2ClientConn
- closeWhenIdle bool
+ _ http2incomparable
+ cc *http2ClientConn
}
// readLoop runs in its own goroutine and reads and dispatches frames.
@@ -8424,23 +8761,49 @@ func (rl *http2clientConnReadLoop) cleanup() {
} else if err == io.EOF {
err = io.ErrUnexpectedEOF
}
+ cc.closed = true
for _, cs := range cc.streams {
- cs.bufPipe.CloseWithError(err) // no-op if already closed
select {
- case cs.resc <- http2resAndError{err: err}:
+ case <-cs.peerClosed:
+ // The server closed the stream before closing the conn,
+ // so no need to interrupt it.
default:
+ cs.abortStreamLocked(err)
}
- close(cs.done)
}
- cc.closed = true
cc.cond.Broadcast()
cc.mu.Unlock()
}
+// countReadFrameError calls Transport.CountError with a string
+// representing err.
+func (cc *http2ClientConn) countReadFrameError(err error) {
+ f := cc.t.CountError
+ if f == nil || err == nil {
+ return
+ }
+ if ce, ok := err.(http2ConnectionError); ok {
+ errCode := http2ErrCode(ce)
+ f(fmt.Sprintf("read_frame_conn_error_%s", errCode.stringToken()))
+ return
+ }
+ if errors.Is(err, io.EOF) {
+ f("read_frame_eof")
+ return
+ }
+ if errors.Is(err, io.ErrUnexpectedEOF) {
+ f("read_frame_unexpected_eof")
+ return
+ }
+ if errors.Is(err, http2ErrFrameTooLarge) {
+ f("read_frame_too_large")
+ return
+ }
+ f("read_frame_other")
+}
+
func (rl *http2clientConnReadLoop) run() error {
cc := rl.cc
- rl.closeWhenIdle = cc.t.disableKeepAlives() || cc.singleUse
- gotReply := false // ever saw a HEADERS reply
gotSettings := false
readIdleTimeout := cc.t.ReadIdleTimeout
var t *time.Timer
@@ -8457,9 +8820,7 @@ func (rl *http2clientConnReadLoop) run() error {
cc.vlogf("http2: Transport readFrame error on conn %p: (%T) %v", cc, err, err)
}
if se, ok := err.(http2StreamError); ok {
- if cs := cc.streamByID(se.StreamID, false); cs != nil {
- cs.cc.writeStreamReset(cs.ID, se.Code, err)
- cs.cc.forgetStreamID(cs.ID)
+ if cs := rl.streamByID(se.StreamID); cs != nil {
if se.Cause == nil {
se.Cause = cc.fr.errDetail
}
@@ -8467,6 +8828,7 @@ func (rl *http2clientConnReadLoop) run() error {
}
continue
} else if err != nil {
+ cc.countReadFrameError(err)
return err
}
if http2VerboseLogs {
@@ -8479,22 +8841,16 @@ func (rl *http2clientConnReadLoop) run() error {
}
gotSettings = true
}
- maybeIdle := false // whether frame might transition us to idle
switch f := f.(type) {
case *http2MetaHeadersFrame:
err = rl.processHeaders(f)
- maybeIdle = true
- gotReply = true
case *http2DataFrame:
err = rl.processData(f)
- maybeIdle = true
case *http2GoAwayFrame:
err = rl.processGoAway(f)
- maybeIdle = true
case *http2RSTStreamFrame:
err = rl.processResetStream(f)
- maybeIdle = true
case *http2SettingsFrame:
err = rl.processSettings(f)
case *http2PushPromiseFrame:
@@ -8512,38 +8868,24 @@ func (rl *http2clientConnReadLoop) run() error {
}
return err
}
- if rl.closeWhenIdle && gotReply && maybeIdle {
- cc.closeIfIdle()
- }
}
}
func (rl *http2clientConnReadLoop) processHeaders(f *http2MetaHeadersFrame) error {
- cc := rl.cc
- cs := cc.streamByID(f.StreamID, false)
+ cs := rl.streamByID(f.StreamID)
if cs == nil {
// We'd get here if we canceled a request while the
// server had its response still in flight. So if this
// was just something we canceled, ignore it.
return nil
}
- if f.StreamEnded() {
- // Issue 20521: If the stream has ended, streamByID() causes
- // clientStream.done to be closed, which causes the request's bodyWriter
- // to be closed with an errStreamClosed, which may be received by
- // clientConn.RoundTrip before the result of processing these headers.
- // Deferring stream closure allows the header processing to occur first.
- // clientConn.RoundTrip may still receive the bodyWriter error first, but
- // the fix for issue 16102 prioritises any response.
- //
- // Issue 22413: If there is no request body, we should close the
- // stream before writing to cs.resc so that the stream is closed
- // immediately once RoundTrip returns.
- if cs.req.Body != nil {
- defer cc.forgetStreamID(f.StreamID)
- } else {
- cc.forgetStreamID(f.StreamID)
- }
+ if cs.readClosed {
+ rl.endStreamError(cs, http2StreamError{
+ StreamID: f.StreamID,
+ Code: http2ErrCodeProtocol,
+ Cause: errors.New("protocol error: headers after END_STREAM"),
+ })
+ return nil
}
if !cs.firstByte {
if cs.trace != nil {
@@ -8567,9 +8909,11 @@ func (rl *http2clientConnReadLoop) processHeaders(f *http2MetaHeadersFrame) erro
return err
}
// Any other error type is a stream error.
- cs.cc.writeStreamReset(f.StreamID, http2ErrCodeProtocol, err)
- cc.forgetStreamID(cs.ID)
- cs.resc <- http2resAndError{err: err}
+ rl.endStreamError(cs, http2StreamError{
+ StreamID: f.StreamID,
+ Code: http2ErrCodeProtocol,
+ Cause: err,
+ })
return nil // return nil from process* funcs to keep conn alive
}
if res == nil {
@@ -8577,7 +8921,11 @@ func (rl *http2clientConnReadLoop) processHeaders(f *http2MetaHeadersFrame) erro
return nil
}
cs.resTrailer = &res.Trailer
- cs.resc <- http2resAndError{res: res}
+ cs.res = res
+ close(cs.respHeaderRecv)
+ if f.StreamEnded() {
+ rl.endStream(cs)
+ }
return nil
}
@@ -8639,6 +8987,9 @@ func (rl *http2clientConnReadLoop) handleResponse(cs *http2clientStream, f *http
}
if statusCode >= 100 && statusCode <= 199 {
+ if f.StreamEnded() {
+ return nil, errors.New("1xx informational response with END_STREAM flag")
+ }
cs.num1xx++
const max1xxResponses = 5 // arbitrary bound on number of informational responses, same as net/http
if cs.num1xx > max1xxResponses {
@@ -8651,42 +9002,49 @@ func (rl *http2clientConnReadLoop) handleResponse(cs *http2clientStream, f *http
}
if statusCode == 100 {
http2traceGot100Continue(cs.trace)
- if cs.on100 != nil {
- cs.on100() // forces any write delay timer to fire
+ select {
+ case cs.on100 <- struct{}{}:
+ default:
}
}
cs.pastHeaders = false // do it all again
return nil, nil
}
- streamEnded := f.StreamEnded()
- isHead := cs.req.Method == "HEAD"
- if !streamEnded || isHead {
- res.ContentLength = -1
- if clens := res.Header["Content-Length"]; len(clens) == 1 {
- if cl, err := strconv.ParseUint(clens[0], 10, 63); err == nil {
- res.ContentLength = int64(cl)
- } else {
- // TODO: care? unlike http/1, it won't mess up our framing, so it's
- // more safe smuggling-wise to ignore.
- }
- } else if len(clens) > 1 {
+ res.ContentLength = -1
+ if clens := res.Header["Content-Length"]; len(clens) == 1 {
+ if cl, err := strconv.ParseUint(clens[0], 10, 63); err == nil {
+ res.ContentLength = int64(cl)
+ } else {
// TODO: care? unlike http/1, it won't mess up our framing, so it's
// more safe smuggling-wise to ignore.
}
+ } else if len(clens) > 1 {
+ // TODO: care? unlike http/1, it won't mess up our framing, so it's
+ // more safe smuggling-wise to ignore.
+ } else if f.StreamEnded() && !cs.isHead {
+ res.ContentLength = 0
}
- if streamEnded || isHead {
+ if cs.isHead {
res.Body = http2noBody
return res, nil
}
- cs.bufPipe = http2pipe{b: &http2dataBuffer{expected: res.ContentLength}}
+ if f.StreamEnded() {
+ if res.ContentLength > 0 {
+ res.Body = http2missingBody{}
+ } else {
+ res.Body = http2noBody
+ }
+ return res, nil
+ }
+
+ cs.bufPipe.setBuffer(&http2dataBuffer{expected: res.ContentLength})
cs.bytesRemain = res.ContentLength
res.Body = http2transportResponseBody{cs}
- go cs.awaitRequestCancel(cs.req)
- if cs.requestedGzip && res.Header.Get("Content-Encoding") == "gzip" {
+ if cs.requestedGzip && http2asciiEqualFold(res.Header.Get("Content-Encoding"), "gzip") {
res.Header.Del("Content-Encoding")
res.Header.Del("Content-Length")
res.ContentLength = -1
@@ -8725,8 +9083,7 @@ func (rl *http2clientConnReadLoop) processTrailers(cs *http2clientStream, f *htt
}
// transportResponseBody is the concrete type of Transport.RoundTrip's
-// Response.Body. It is an io.ReadCloser. On Read, it reads from cs.body.
-// On Close it sends RST_STREAM if EOF wasn't already seen.
+// Response.Body. It is an io.ReadCloser.
type http2transportResponseBody struct {
cs *http2clientStream
}
@@ -8744,7 +9101,7 @@ func (b http2transportResponseBody) Read(p []byte) (n int, err error) {
n = int(cs.bytesRemain)
if err == nil {
err = errors.New("net/http: server replied with more than declared Content-Length; truncated")
- cc.writeStreamReset(cs.ID, http2ErrCodeProtocol, err)
+ cs.abortStream(err)
}
cs.readErr = err
return int(cs.bytesRemain), err
@@ -8762,8 +9119,6 @@ func (b http2transportResponseBody) Read(p []byte) (n int, err error) {
}
cc.mu.Lock()
- defer cc.mu.Unlock()
-
var connAdd, streamAdd int32
// Check the conn-level first, before the stream-level.
if v := cc.inflow.available(); v < http2transportDefaultConnFlow/2 {
@@ -8780,6 +9135,8 @@ func (b http2transportResponseBody) Read(p []byte) (n int, err error) {
cs.inflow.add(streamAdd)
}
}
+ cc.mu.Unlock()
+
if connAdd != 0 || streamAdd != 0 {
cc.wmu.Lock()
defer cc.wmu.Unlock()
@@ -8800,34 +9157,45 @@ func (b http2transportResponseBody) Close() error {
cs := b.cs
cc := cs.cc
- serverSentStreamEnd := cs.bufPipe.Err() == io.EOF
unread := cs.bufPipe.Len()
-
- if unread > 0 || !serverSentStreamEnd {
+ if unread > 0 {
cc.mu.Lock()
- cc.wmu.Lock()
- if !serverSentStreamEnd {
- cc.fr.WriteRSTStream(cs.ID, http2ErrCodeCancel)
- cs.didReset = true
- }
// Return connection-level flow control.
if unread > 0 {
cc.inflow.add(int32(unread))
+ }
+ cc.mu.Unlock()
+
+ // TODO(dneil): Acquiring this mutex can block indefinitely.
+ // Move flow control return to a goroutine?
+ cc.wmu.Lock()
+ // Return connection-level flow control.
+ if unread > 0 {
cc.fr.WriteWindowUpdate(0, uint32(unread))
}
cc.bw.Flush()
cc.wmu.Unlock()
- cc.mu.Unlock()
}
cs.bufPipe.BreakWithError(http2errClosedResponseBody)
- cc.forgetStreamID(cs.ID)
+ cs.abortStream(http2errClosedResponseBody)
+
+ select {
+ case <-cs.donec:
+ case <-cs.ctx.Done():
+ // See golang/go#49366: The net/http package can cancel the
+ // request context after the response body is fully read.
+ // Don't treat this as an error.
+ return nil
+ case <-cs.reqCancel:
+ return http2errRequestCanceled
+ }
return nil
}
func (rl *http2clientConnReadLoop) processData(f *http2DataFrame) error {
cc := rl.cc
- cs := cc.streamByID(f.StreamID, f.StreamEnded())
+ cs := rl.streamByID(f.StreamID)
data := f.Data()
if cs == nil {
cc.mu.Lock()
@@ -8856,6 +9224,14 @@ func (rl *http2clientConnReadLoop) processData(f *http2DataFrame) error {
}
return nil
}
+ if cs.readClosed {
+ cc.logf("protocol error: received DATA after END_STREAM")
+ rl.endStreamError(cs, http2StreamError{
+ StreamID: f.StreamID,
+ Code: http2ErrCodeProtocol,
+ })
+ return nil
+ }
if !cs.firstByte {
cc.logf("protocol error: received DATA before a HEADERS frame")
rl.endStreamError(cs, http2StreamError{
@@ -8865,7 +9241,7 @@ func (rl *http2clientConnReadLoop) processData(f *http2DataFrame) error {
return nil
}
if f.Length > 0 {
- if cs.req.Method == "HEAD" && len(data) > 0 {
+ if cs.isHead && len(data) > 0 {
cc.logf("protocol error: received DATA on a HEAD request")
rl.endStreamError(cs, http2StreamError{
StreamID: f.StreamID,
@@ -8887,30 +9263,39 @@ func (rl *http2clientConnReadLoop) processData(f *http2DataFrame) error {
if pad := int(f.Length) - len(data); pad > 0 {
refund += pad
}
- // Return len(data) now if the stream is already closed,
- // since data will never be read.
- didReset := cs.didReset
- if didReset {
- refund += len(data)
+
+ didReset := false
+ var err error
+ if len(data) > 0 {
+ if _, err = cs.bufPipe.Write(data); err != nil {
+ // Return len(data) now if the stream is already closed,
+ // since data will never be read.
+ didReset = true
+ refund += len(data)
+ }
}
+
if refund > 0 {
cc.inflow.add(int32(refund))
+ if !didReset {
+ cs.inflow.add(int32(refund))
+ }
+ }
+ cc.mu.Unlock()
+
+ if refund > 0 {
cc.wmu.Lock()
cc.fr.WriteWindowUpdate(0, uint32(refund))
if !didReset {
- cs.inflow.add(int32(refund))
cc.fr.WriteWindowUpdate(cs.ID, uint32(refund))
}
cc.bw.Flush()
cc.wmu.Unlock()
}
- cc.mu.Unlock()
- if len(data) > 0 && !didReset {
- if _, err := cs.bufPipe.Write(data); err != nil {
- rl.endStreamError(cs, err)
- return err
- }
+ if err != nil {
+ rl.endStreamError(cs, err)
+ return nil
}
}
@@ -8923,24 +9308,32 @@ func (rl *http2clientConnReadLoop) processData(f *http2DataFrame) error {
func (rl *http2clientConnReadLoop) endStream(cs *http2clientStream) {
// TODO: check that any declared content-length matches, like
// server.go's (*stream).endStream method.
- rl.endStreamError(cs, nil)
+ if !cs.readClosed {
+ cs.readClosed = true
+ // Close cs.bufPipe and cs.peerClosed with cc.mu held to avoid a
+ // race condition: The caller can read io.EOF from Response.Body
+ // and close the body before we close cs.peerClosed, causing
+ // cleanupWriteRequest to send a RST_STREAM.
+ rl.cc.mu.Lock()
+ defer rl.cc.mu.Unlock()
+ cs.bufPipe.closeWithErrorAndCode(io.EOF, cs.copyTrailers)
+ close(cs.peerClosed)
+ }
}
func (rl *http2clientConnReadLoop) endStreamError(cs *http2clientStream, err error) {
- var code func()
- if err == nil {
- err = io.EOF
- code = cs.copyTrailers
- }
- if http2isConnectionCloseRequest(cs.req) {
- rl.closeWhenIdle = true
- }
- cs.bufPipe.closeWithErrorAndCode(err, code)
+ cs.readAborted = true
+ cs.abortStream(err)
+}
- select {
- case cs.resc <- http2resAndError{err: err}:
- default:
+func (rl *http2clientConnReadLoop) streamByID(id uint32) *http2clientStream {
+ rl.cc.mu.Lock()
+ defer rl.cc.mu.Unlock()
+ cs := rl.cc.streams[id]
+ if cs != nil && !cs.readAborted {
+ return cs
}
+ return nil
}
func (cs *http2clientStream) copyTrailers() {
@@ -8959,6 +9352,10 @@ func (rl *http2clientConnReadLoop) processGoAway(f *http2GoAwayFrame) error {
if f.ErrCode != 0 {
// TODO: deal with GOAWAY more. particularly the error code
cc.vlogf("transport got GOAWAY with error code = %v", f.ErrCode)
+ if fn := cc.t.CountError; fn != nil {
+ fn("recv_goaway_" + f.ErrCode.stringToken())
+ }
+
}
cc.setGoAway(f)
return nil
@@ -8966,6 +9363,23 @@ func (rl *http2clientConnReadLoop) processGoAway(f *http2GoAwayFrame) error {
func (rl *http2clientConnReadLoop) processSettings(f *http2SettingsFrame) error {
cc := rl.cc
+ // Locking both mu and wmu here allows frame encoding to read settings with only wmu held.
+ // Acquiring wmu when f.IsAck() is unnecessary, but convenient and mostly harmless.
+ cc.wmu.Lock()
+ defer cc.wmu.Unlock()
+
+ if err := rl.processSettingsNoWrite(f); err != nil {
+ return err
+ }
+ if !f.IsAck() {
+ cc.fr.WriteSettingsAck()
+ cc.bw.Flush()
+ }
+ return nil
+}
+
+func (rl *http2clientConnReadLoop) processSettingsNoWrite(f *http2SettingsFrame) error {
+ cc := rl.cc
cc.mu.Lock()
defer cc.mu.Unlock()
@@ -8977,12 +9391,14 @@ func (rl *http2clientConnReadLoop) processSettings(f *http2SettingsFrame) error
return http2ConnectionError(http2ErrCodeProtocol)
}
+ var seenMaxConcurrentStreams bool
err := f.ForeachSetting(func(s http2Setting) error {
switch s.ID {
case http2SettingMaxFrameSize:
cc.maxFrameSize = s.Val
case http2SettingMaxConcurrentStreams:
cc.maxConcurrentStreams = s.Val
+ seenMaxConcurrentStreams = true
case http2SettingMaxHeaderListSize:
cc.peerMaxHeaderListSize = uint64(s.Val)
case http2SettingInitialWindowSize:
@@ -9014,17 +9430,23 @@ func (rl *http2clientConnReadLoop) processSettings(f *http2SettingsFrame) error
return err
}
- cc.wmu.Lock()
- defer cc.wmu.Unlock()
+ if !cc.seenSettings {
+ if !seenMaxConcurrentStreams {
+ // This was the servers initial SETTINGS frame and it
+ // didn't contain a MAX_CONCURRENT_STREAMS field so
+ // increase the number of concurrent streams this
+ // connection can establish to our default.
+ cc.maxConcurrentStreams = http2defaultMaxConcurrentStreams
+ }
+ cc.seenSettings = true
+ }
- cc.fr.WriteSettingsAck()
- cc.bw.Flush()
- return cc.werr
+ return nil
}
func (rl *http2clientConnReadLoop) processWindowUpdate(f *http2WindowUpdateFrame) error {
cc := rl.cc
- cs := cc.streamByID(f.StreamID, false)
+ cs := rl.streamByID(f.StreamID)
if f.StreamID != 0 && cs == nil {
return nil
}
@@ -9044,24 +9466,22 @@ func (rl *http2clientConnReadLoop) processWindowUpdate(f *http2WindowUpdateFrame
}
func (rl *http2clientConnReadLoop) processResetStream(f *http2RSTStreamFrame) error {
- cs := rl.cc.streamByID(f.StreamID, true)
+ cs := rl.streamByID(f.StreamID)
if cs == nil {
- // TODO: return error if server tries to RST_STEAM an idle stream
+ // TODO: return error if server tries to RST_STREAM an idle stream
return nil
}
- select {
- case <-cs.peerReset:
- // Already reset.
- // This is the only goroutine
- // which closes this, so there
- // isn't a race.
- default:
- err := http2streamError(cs.ID, f.ErrCode)
- cs.resetErr = err
- close(cs.peerReset)
- cs.bufPipe.CloseWithError(err)
- cs.cc.cond.Broadcast() // wake up checkResetOrDone via clientStream.awaitFlowControl
+ serr := http2streamError(cs.ID, f.ErrCode)
+ serr.Cause = http2errFromPeer
+ if f.ErrCode == http2ErrCodeProtocol {
+ rl.cc.SetDoNotReuse()
}
+ if fn := cs.cc.t.CountError; fn != nil {
+ fn("recv_rststream_" + f.ErrCode.stringToken())
+ }
+ cs.abortStream(serr)
+
+ cs.bufPipe.CloseWithError(serr)
return nil
}
@@ -9083,19 +9503,24 @@ func (cc *http2ClientConn) Ping(ctx context.Context) error {
}
cc.mu.Unlock()
}
- cc.wmu.Lock()
- if err := cc.fr.WritePing(false, p); err != nil {
- cc.wmu.Unlock()
- return err
- }
- if err := cc.bw.Flush(); err != nil {
- cc.wmu.Unlock()
- return err
- }
- cc.wmu.Unlock()
+ errc := make(chan error, 1)
+ go func() {
+ cc.wmu.Lock()
+ defer cc.wmu.Unlock()
+ if err := cc.fr.WritePing(false, p); err != nil {
+ errc <- err
+ return
+ }
+ if err := cc.bw.Flush(); err != nil {
+ errc <- err
+ return
+ }
+ }()
select {
case <-c:
return nil
+ case err := <-errc:
+ return err
case <-ctx.Done():
return ctx.Err()
case <-cc.readerDone:
@@ -9172,6 +9597,12 @@ func (t *http2Transport) logf(format string, args ...interface{}) {
var http2noBody io.ReadCloser = ioutil.NopCloser(bytes.NewReader(nil))
+type http2missingBody struct{}
+
+func (http2missingBody) Close() error { return nil }
+
+func (http2missingBody) Read([]byte) (int, error) { return 0, io.ErrUnexpectedEOF }
+
func http2strSliceContains(ss []string, s string) bool {
for _, v := range ss {
if v == s {
@@ -9218,87 +9649,6 @@ type http2errorReader struct{ err error }
func (r http2errorReader) Read(p []byte) (int, error) { return 0, r.err }
-// bodyWriterState encapsulates various state around the Transport's writing
-// of the request body, particularly regarding doing delayed writes of the body
-// when the request contains "Expect: 100-continue".
-type http2bodyWriterState struct {
- cs *http2clientStream
- timer *time.Timer // if non-nil, we're doing a delayed write
- fnonce *sync.Once // to call fn with
- fn func() // the code to run in the goroutine, writing the body
- resc chan error // result of fn's execution
- delay time.Duration // how long we should delay a delayed write for
-}
-
-func (t *http2Transport) getBodyWriterState(cs *http2clientStream, body io.Reader) (s http2bodyWriterState) {
- s.cs = cs
- if body == nil {
- return
- }
- resc := make(chan error, 1)
- s.resc = resc
- s.fn = func() {
- cs.cc.mu.Lock()
- cs.startedWrite = true
- cs.cc.mu.Unlock()
- resc <- cs.writeRequestBody(body, cs.req.Body)
- }
- s.delay = t.expectContinueTimeout()
- if s.delay == 0 ||
- !httpguts.HeaderValuesContainsToken(
- cs.req.Header["Expect"],
- "100-continue") {
- return
- }
- s.fnonce = new(sync.Once)
-
- // Arm the timer with a very large duration, which we'll
- // intentionally lower later. It has to be large now because
- // we need a handle to it before writing the headers, but the
- // s.delay value is defined to not start until after the
- // request headers were written.
- const hugeDuration = 365 * 24 * time.Hour
- s.timer = time.AfterFunc(hugeDuration, func() {
- s.fnonce.Do(s.fn)
- })
- return
-}
-
-func (s http2bodyWriterState) cancel() {
- if s.timer != nil {
- if s.timer.Stop() {
- s.resc <- nil
- }
- }
-}
-
-func (s http2bodyWriterState) on100() {
- if s.timer == nil {
- // If we didn't do a delayed write, ignore the server's
- // bogus 100 continue response.
- return
- }
- s.timer.Stop()
- go func() { s.fnonce.Do(s.fn) }()
-}
-
-// scheduleBodyWrite starts writing the body, either immediately (in
-// the common case) or after the delay timeout. It should not be
-// called until after the headers have been written.
-func (s http2bodyWriterState) scheduleBodyWrite() {
- if s.timer == nil {
- // We're not doing a delayed write (see
- // getBodyWriterState), so just start the writing
- // goroutine immediately.
- go s.fn()
- return
- }
- http2traceWait100Continue(s.cs.trace)
- if s.timer.Stop() {
- s.timer.Reset(s.delay)
- }
-}
-
// isConnectionCloseRequest reports whether req should use its own
// connection for a single request and then close the connection.
func http2isConnectionCloseRequest(req *Request) bool {
@@ -9775,7 +10125,8 @@ type http2WriteScheduler interface {
// Pop dequeues the next frame to write. Returns false if no frames can
// be written. Frames with a given wr.StreamID() are Pop'd in the same
- // order they are Push'd. No frames should be discarded except by CloseStream.
+ // order they are Push'd, except RST_STREAM frames. No frames should be
+ // discarded except by CloseStream.
Pop() (wr http2FrameWriteRequest, ok bool)
}
@@ -9795,6 +10146,7 @@ type http2FrameWriteRequest struct {
// stream is the stream on which this frame will be written.
// nil for non-stream frames like PING and SETTINGS.
+ // nil for RST_STREAM streams, which use the StreamError.StreamID field instead.
stream *http2stream
// done, if non-nil, must be a buffered channel with space for
@@ -10474,11 +10826,11 @@ func (ws *http2randomWriteScheduler) AdjustStream(streamID uint32, priority http
}
func (ws *http2randomWriteScheduler) Push(wr http2FrameWriteRequest) {
- id := wr.StreamID()
- if id == 0 {
+ if wr.isControl() {
ws.zero.push(wr)
return
}
+ id := wr.StreamID()
q, ok := ws.sq[id]
if !ok {
q = ws.queuePool.get()
@@ -10488,7 +10840,7 @@ func (ws *http2randomWriteScheduler) Push(wr http2FrameWriteRequest) {
}
func (ws *http2randomWriteScheduler) Pop() (http2FrameWriteRequest, bool) {
- // Control frames first.
+ // Control and RST_STREAM frames first.
if !ws.zero.empty() {
return ws.zero.shift(), true
}
diff --git a/libgo/go/net/http/header.go b/libgo/go/net/http/header.go
index 4c72dcb..6487e50 100644
--- a/libgo/go/net/http/header.go
+++ b/libgo/go/net/http/header.go
@@ -13,6 +13,8 @@ import (
"strings"
"sync"
"time"
+
+ "golang.org/x/net/http/httpguts"
)
// A Header represents the key-value pairs in an HTTP header.
@@ -155,7 +157,7 @@ func (s *headerSorter) Swap(i, j int) { s.kvs[i], s.kvs[j] = s.kvs[j], s.kv
func (s *headerSorter) Less(i, j int) bool { return s.kvs[i].key < s.kvs[j].key }
var headerSorterPool = sync.Pool{
- New: func() interface{} { return new(headerSorter) },
+ New: func() any { return new(headerSorter) },
}
// sortedKeyValues returns h's keys sorted in the returned kvs
@@ -192,6 +194,13 @@ func (h Header) writeSubset(w io.Writer, exclude map[string]bool, trace *httptra
kvs, sorter := h.sortedKeyValues(exclude)
var formattedVals []string
for _, kv := range kvs {
+ if !httpguts.ValidHeaderFieldName(kv.key) {
+ // This could be an error. In the common case of
+ // writing response headers, however, we have no good
+ // way to provide the error back to the server
+ // handler, so just drop invalid headers instead.
+ continue
+ }
for _, v := range kv.values {
v = headerNewlineToSpace.Replace(v)
v = textproto.TrimString(v)
diff --git a/libgo/go/net/http/header_test.go b/libgo/go/net/http/header_test.go
index ad8ab9b..575493b 100644
--- a/libgo/go/net/http/header_test.go
+++ b/libgo/go/net/http/header_test.go
@@ -89,6 +89,19 @@ var headerWriteTests = []struct {
"k4: 4a\r\nk4: 4b\r\nk6: 6a\r\nk6: 6b\r\n" +
"k7: 7a\r\nk7: 7b\r\nk8: 8a\r\nk8: 8b\r\nk9: 9a\r\nk9: 9b\r\n",
},
+ // Tests invalid characters in headers.
+ {
+ Header{
+ "Content-Type": {"text/html; charset=UTF-8"},
+ "NewlineInValue": {"1\r\nBar: 2"},
+ "NewlineInKey\r\n": {"1"},
+ "Colon:InKey": {"1"},
+ "Evil: 1\r\nSmuggledValue": {"1"},
+ },
+ nil,
+ "Content-Type: text/html; charset=UTF-8\r\n" +
+ "NewlineInValue: 1 Bar: 2\r\n",
+ },
}
func TestHeaderWrite(t *testing.T) {
diff --git a/libgo/go/net/http/httptrace/trace.go b/libgo/go/net/http/httptrace/trace.go
index 5777c91..6af30f7 100644
--- a/libgo/go/net/http/httptrace/trace.go
+++ b/libgo/go/net/http/httptrace/trace.go
@@ -50,7 +50,7 @@ func WithClientTrace(ctx context.Context, trace *ClientTrace) context.Context {
}
}
if trace.DNSDone != nil {
- nt.DNSDone = func(netIPs []interface{}, coalesced bool, err error) {
+ nt.DNSDone = func(netIPs []any, coalesced bool, err error) {
addrs := make([]net.IPAddr, len(netIPs))
for i, ip := range netIPs {
addrs[i] = ip.(net.IPAddr)
diff --git a/libgo/go/net/http/httputil/dump.go b/libgo/go/net/http/httputil/dump.go
index 2948f27..d7baecd 100644
--- a/libgo/go/net/http/httputil/dump.go
+++ b/libgo/go/net/http/httputil/dump.go
@@ -292,7 +292,7 @@ func DumpRequest(req *http.Request, body bool) ([]byte, error) {
// can detect that the lack of body was intentional.
var errNoBody = errors.New("sentinel error value")
-// failureToReadBody is a io.ReadCloser that just returns errNoBody on
+// failureToReadBody is an io.ReadCloser that just returns errNoBody on
// Read. It's swapped in when we don't actually want to consume
// the body, but need a non-nil one, and want to distinguish the
// error from reading the dummy body.
diff --git a/libgo/go/net/http/httputil/dump_test.go b/libgo/go/net/http/httputil/dump_test.go
index 366cc82..5df2ee8 100644
--- a/libgo/go/net/http/httputil/dump_test.go
+++ b/libgo/go/net/http/httputil/dump_test.go
@@ -31,7 +31,7 @@ type dumpTest struct {
Req *http.Request
GetReq func() *http.Request
- Body interface{} // optional []byte or func() io.ReadCloser to populate Req.Body
+ Body any // optional []byte or func() io.ReadCloser to populate Req.Body
WantDump string
WantDumpOut string
diff --git a/libgo/go/net/http/httputil/reverseproxy.go b/libgo/go/net/http/httputil/reverseproxy.go
index 8b63368..319e2a3 100644
--- a/libgo/go/net/http/httputil/reverseproxy.go
+++ b/libgo/go/net/http/httputil/reverseproxy.go
@@ -11,6 +11,7 @@ import (
"fmt"
"io"
"log"
+ "mime"
"net"
"net/http"
"net/http/internal/ascii"
@@ -412,7 +413,7 @@ func (p *ReverseProxy) flushInterval(res *http.Response) time.Duration {
// For Server-Sent Events responses, flush immediately.
// The MIME type is defined in https://www.w3.org/TR/eventsource/#text-event-stream
- if resCT == "text/event-stream" {
+ if baseCT, _, _ := mime.ParseMediaType(resCT); baseCT == "text/event-stream" {
return -1 // negative means immediately
}
@@ -483,7 +484,7 @@ func (p *ReverseProxy) copyBuffer(dst io.Writer, src io.Reader, buf []byte) (int
}
}
-func (p *ReverseProxy) logf(format string, args ...interface{}) {
+func (p *ReverseProxy) logf(format string, args ...any) {
if p.ErrorLog != nil {
p.ErrorLog.Printf(format, args...)
} else {
diff --git a/libgo/go/net/http/httputil/reverseproxy_test.go b/libgo/go/net/http/httputil/reverseproxy_test.go
index 4b6ad77..90e8903 100644
--- a/libgo/go/net/http/httputil/reverseproxy_test.go
+++ b/libgo/go/net/http/httputil/reverseproxy_test.go
@@ -1195,6 +1195,26 @@ func TestSelectFlushInterval(t *testing.T) {
want: -1,
},
{
+ name: "server-sent events with media-type parameters overrides non-zero",
+ res: &http.Response{
+ Header: http.Header{
+ "Content-Type": {"text/event-stream;charset=utf-8"},
+ },
+ },
+ p: &ReverseProxy{FlushInterval: 123},
+ want: -1,
+ },
+ {
+ name: "server-sent events with media-type parameters overrides zero",
+ res: &http.Response{
+ Header: http.Header{
+ "Content-Type": {"text/event-stream;charset=utf-8"},
+ },
+ },
+ p: &ReverseProxy{FlushInterval: 0},
+ want: -1,
+ },
+ {
name: "Content-Length: -1, overrides non-zero",
res: &http.Response{
ContentLength: -1,
diff --git a/libgo/go/net/http/internal/chunked.go b/libgo/go/net/http/internal/chunked.go
index f06e572..37a72e9 100644
--- a/libgo/go/net/http/internal/chunked.go
+++ b/libgo/go/net/http/internal/chunked.go
@@ -81,6 +81,11 @@ func (cr *chunkedReader) Read(b []uint8) (n int, err error) {
cr.err = errors.New("malformed chunked encoding")
break
}
+ } else {
+ if cr.err == io.EOF {
+ cr.err = io.ErrUnexpectedEOF
+ }
+ break
}
cr.checkEnd = false
}
@@ -109,6 +114,8 @@ func (cr *chunkedReader) Read(b []uint8) (n int, err error) {
// bytes to verify they are "\r\n".
if cr.n == 0 && cr.err == nil {
cr.checkEnd = true
+ } else if cr.err == io.EOF {
+ cr.err = io.ErrUnexpectedEOF
}
}
return n, cr.err
@@ -152,6 +159,8 @@ func isASCIISpace(b byte) bool {
return b == ' ' || b == '\t' || b == '\n' || b == '\r'
}
+var semi = []byte(";")
+
// removeChunkExtension removes any chunk-extension from p.
// For example,
// "0" => "0"
@@ -159,14 +168,11 @@ func isASCIISpace(b byte) bool {
// "0;token=val" => "0"
// `0;token="quoted string"` => "0"
func removeChunkExtension(p []byte) ([]byte, error) {
- semi := bytes.IndexByte(p, ';')
- if semi == -1 {
- return p, nil
- }
+ p, _, _ = bytes.Cut(p, semi)
// TODO: care about exact syntax of chunk extensions? We're
// ignoring and stripping them anyway. For now just never
// return an error.
- return p[:semi], nil
+ return p, nil
}
// NewChunkedWriter returns a new chunkedWriter that translates writes into HTTP
diff --git a/libgo/go/net/http/internal/chunked_test.go b/libgo/go/net/http/internal/chunked_test.go
index 08152ed..5e29a78 100644
--- a/libgo/go/net/http/internal/chunked_test.go
+++ b/libgo/go/net/http/internal/chunked_test.go
@@ -11,6 +11,7 @@ import (
"io"
"strings"
"testing"
+ "testing/iotest"
)
func TestChunk(t *testing.T) {
@@ -211,3 +212,30 @@ func TestChunkReadPartial(t *testing.T) {
}
}
+
+// Issue 48861: ChunkedReader should report incomplete chunks
+func TestIncompleteChunk(t *testing.T) {
+ const valid = "4\r\nabcd\r\n" + "5\r\nabc\r\n\r\n" + "0\r\n"
+
+ for i := 0; i < len(valid); i++ {
+ incomplete := valid[:i]
+ r := NewChunkedReader(strings.NewReader(incomplete))
+ if _, err := io.ReadAll(r); err != io.ErrUnexpectedEOF {
+ t.Errorf("expected io.ErrUnexpectedEOF for %q, got %v", incomplete, err)
+ }
+ }
+
+ r := NewChunkedReader(strings.NewReader(valid))
+ if _, err := io.ReadAll(r); err != nil {
+ t.Errorf("unexpected error for %q: %v", valid, err)
+ }
+}
+
+func TestChunkEndReadError(t *testing.T) {
+ readErr := fmt.Errorf("chunk end read error")
+
+ r := NewChunkedReader(io.MultiReader(strings.NewReader("4\r\nabcd"), iotest.ErrReader(readErr)))
+ if _, err := io.ReadAll(r); err != readErr {
+ t.Errorf("expected %v, got %v", readErr, err)
+ }
+}
diff --git a/libgo/go/net/http/internal/testcert/testcert.go b/libgo/go/net/http/internal/testcert/testcert.go
index 5f94704..d510e79 100644
--- a/libgo/go/net/http/internal/testcert/testcert.go
+++ b/libgo/go/net/http/internal/testcert/testcert.go
@@ -10,37 +10,56 @@ import "strings"
// LocalhostCert is a PEM-encoded TLS cert with SAN IPs
// "127.0.0.1" and "[::1]", expiring at Jan 29 16:00:00 2084 GMT.
// generated from src/crypto/tls:
-// go run generate_cert.go --rsa-bits 1024 --host 127.0.0.1,::1,example.com --ca --start-date "Jan 1 00:00:00 1970" --duration=1000000h
+// go run generate_cert.go --rsa-bits 2048 --host 127.0.0.1,::1,example.com --ca --start-date "Jan 1 00:00:00 1970" --duration=1000000h
var LocalhostCert = []byte(`-----BEGIN CERTIFICATE-----
-MIICEzCCAXygAwIBAgIQMIMChMLGrR+QvmQvpwAU6zANBgkqhkiG9w0BAQsFADAS
+MIIDOTCCAiGgAwIBAgIQSRJrEpBGFc7tNb1fb5pKFzANBgkqhkiG9w0BAQsFADAS
MRAwDgYDVQQKEwdBY21lIENvMCAXDTcwMDEwMTAwMDAwMFoYDzIwODQwMTI5MTYw
-MDAwWjASMRAwDgYDVQQKEwdBY21lIENvMIGfMA0GCSqGSIb3DQEBAQUAA4GNADCB
-iQKBgQDuLnQAI3mDgey3VBzWnB2L39JUU4txjeVE6myuDqkM/uGlfjb9SjY1bIw4
-iA5sBBZzHi3z0h1YV8QPuxEbi4nW91IJm2gsvvZhIrCHS3l6afab4pZBl2+XsDul
-rKBxKKtD1rGxlG4LjncdabFn9gvLZad2bSysqz/qTAUStTvqJQIDAQABo2gwZjAO
-BgNVHQ8BAf8EBAMCAqQwEwYDVR0lBAwwCgYIKwYBBQUHAwEwDwYDVR0TAQH/BAUw
-AwEB/zAuBgNVHREEJzAlggtleGFtcGxlLmNvbYcEfwAAAYcQAAAAAAAAAAAAAAAA
-AAAAATANBgkqhkiG9w0BAQsFAAOBgQCEcetwO59EWk7WiJsG4x8SY+UIAA+flUI9
-tyC4lNhbcF2Idq9greZwbYCqTTTr2XiRNSMLCOjKyI7ukPoPjo16ocHj+P3vZGfs
-h1fIw3cSS2OolhloGw/XM6RWPWtPAlGykKLciQrBru5NAPvCMsb/I1DAceTiotQM
-fblo6RBxUQ==
+MDAwWjASMRAwDgYDVQQKEwdBY21lIENvMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8A
+MIIBCgKCAQEA6Gba5tHV1dAKouAaXO3/ebDUU4rvwCUg/CNaJ2PT5xLD4N1Vcb8r
+bFSW2HXKq+MPfVdwIKR/1DczEoAGf/JWQTW7EgzlXrCd3rlajEX2D73faWJekD0U
+aUgz5vtrTXZ90BQL7WvRICd7FlEZ6FPOcPlumiyNmzUqtwGhO+9ad1W5BqJaRI6P
+YfouNkwR6Na4TzSj5BrqUfP0FwDizKSJ0XXmh8g8G9mtwxOSN3Ru1QFc61Xyeluk
+POGKBV/q6RBNklTNe0gI8usUMlYyoC7ytppNMW7X2vodAelSu25jgx2anj9fDVZu
+h7AXF5+4nJS4AAt0n1lNY7nGSsdZas8PbQIDAQABo4GIMIGFMA4GA1UdDwEB/wQE
+AwICpDATBgNVHSUEDDAKBggrBgEFBQcDATAPBgNVHRMBAf8EBTADAQH/MB0GA1Ud
+DgQWBBStsdjh3/JCXXYlQryOrL4Sh7BW5TAuBgNVHREEJzAlggtleGFtcGxlLmNv
+bYcEfwAAAYcQAAAAAAAAAAAAAAAAAAAAATANBgkqhkiG9w0BAQsFAAOCAQEAxWGI
+5NhpF3nwwy/4yB4i/CwwSpLrWUa70NyhvprUBC50PxiXav1TeDzwzLx/o5HyNwsv
+cxv3HdkLW59i/0SlJSrNnWdfZ19oTcS+6PtLoVyISgtyN6DpkKpdG1cOkW3Cy2P2
++tK/tKHRP1Y/Ra0RiDpOAmqn0gCOFGz8+lqDIor/T7MTpibL3IxqWfPrvfVRHL3B
+grw/ZQTTIVjjh4JBSW3WyWgNo/ikC1lrVxzl4iPUGptxT36Cr7Zk2Bsg0XqwbOvK
+5d+NTDREkSnUbie4GeutujmX3Dsx88UiV6UY/4lHJa6I5leHUNOHahRbpbWeOfs/
+WkBKOclmOV2xlTVuPw==
-----END CERTIFICATE-----`)
// LocalhostKey is the private key for LocalhostCert.
var LocalhostKey = []byte(testingKey(`-----BEGIN RSA TESTING KEY-----
-MIICXgIBAAKBgQDuLnQAI3mDgey3VBzWnB2L39JUU4txjeVE6myuDqkM/uGlfjb9
-SjY1bIw4iA5sBBZzHi3z0h1YV8QPuxEbi4nW91IJm2gsvvZhIrCHS3l6afab4pZB
-l2+XsDulrKBxKKtD1rGxlG4LjncdabFn9gvLZad2bSysqz/qTAUStTvqJQIDAQAB
-AoGAGRzwwir7XvBOAy5tM/uV6e+Zf6anZzus1s1Y1ClbjbE6HXbnWWF/wbZGOpet
-3Zm4vD6MXc7jpTLryzTQIvVdfQbRc6+MUVeLKwZatTXtdZrhu+Jk7hx0nTPy8Jcb
-uJqFk541aEw+mMogY/xEcfbWd6IOkp+4xqjlFLBEDytgbIECQQDvH/E6nk+hgN4H
-qzzVtxxr397vWrjrIgPbJpQvBsafG7b0dA4AFjwVbFLmQcj2PprIMmPcQrooz8vp
-jy4SHEg1AkEA/v13/5M47K9vCxmb8QeD/asydfsgS5TeuNi8DoUBEmiSJwma7FXY
-fFUtxuvL7XvjwjN5B30pNEbc6Iuyt7y4MQJBAIt21su4b3sjXNueLKH85Q+phy2U
-fQtuUE9txblTu14q3N7gHRZB4ZMhFYyDy8CKrN2cPg/Fvyt0Xlp/DoCzjA0CQQDU
-y2ptGsuSmgUtWj3NM9xuwYPm+Z/F84K6+ARYiZ6PYj013sovGKUFfYAqVXVlxtIX
-qyUBnu3X9ps8ZfjLZO7BAkEAlT4R5Yl6cGhaJQYZHOde3JEMhNRcVFMO8dJDaFeo
-f9Oeos0UUothgiDktdQHxdNEwLjQf7lJJBzV+5OtwswCWA==
+MIIEvAIBADANBgkqhkiG9w0BAQEFAASCBKYwggSiAgEAAoIBAQDoZtrm0dXV0Aqi
+4Bpc7f95sNRTiu/AJSD8I1onY9PnEsPg3VVxvytsVJbYdcqr4w99V3AgpH/UNzMS
+gAZ/8lZBNbsSDOVesJ3euVqMRfYPvd9pYl6QPRRpSDPm+2tNdn3QFAvta9EgJ3sW
+URnoU85w+W6aLI2bNSq3AaE771p3VbkGolpEjo9h+i42TBHo1rhPNKPkGupR8/QX
+AOLMpInRdeaHyDwb2a3DE5I3dG7VAVzrVfJ6W6Q84YoFX+rpEE2SVM17SAjy6xQy
+VjKgLvK2mk0xbtfa+h0B6VK7bmODHZqeP18NVm6HsBcXn7iclLgAC3SfWU1jucZK
+x1lqzw9tAgMBAAECggEABWzxS1Y2wckblnXY57Z+sl6YdmLV+gxj2r8Qib7g4ZIk
+lIlWR1OJNfw7kU4eryib4fc6nOh6O4AWZyYqAK6tqNQSS/eVG0LQTLTTEldHyVJL
+dvBe+MsUQOj4nTndZW+QvFzbcm2D8lY5n2nBSxU5ypVoKZ1EqQzytFcLZpTN7d89
+EPj0qDyrV4NZlWAwL1AygCwnlwhMQjXEalVF1ylXwU3QzyZ/6MgvF6d3SSUlh+sq
+XefuyigXw484cQQgbzopv6niMOmGP3of+yV4JQqUSb3IDmmT68XjGd2Dkxl4iPki
+6ZwXf3CCi+c+i/zVEcufgZ3SLf8D99kUGE7v7fZ6AQKBgQD1ZX3RAla9hIhxCf+O
+3D+I1j2LMrdjAh0ZKKqwMR4JnHX3mjQI6LwqIctPWTU8wYFECSh9klEclSdCa64s
+uI/GNpcqPXejd0cAAdqHEEeG5sHMDt0oFSurL4lyud0GtZvwlzLuwEweuDtvT9cJ
+Wfvl86uyO36IW8JdvUprYDctrQKBgQDycZ697qutBieZlGkHpnYWUAeImVA878sJ
+w44NuXHvMxBPz+lbJGAg8Cn8fcxNAPqHIraK+kx3po8cZGQywKHUWsxi23ozHoxo
++bGqeQb9U661TnfdDspIXia+xilZt3mm5BPzOUuRqlh4Y9SOBpSWRmEhyw76w4ZP
+OPxjWYAgwQKBgA/FehSYxeJgRjSdo+MWnK66tjHgDJE8bYpUZsP0JC4R9DL5oiaA
+brd2fI6Y+SbyeNBallObt8LSgzdtnEAbjIH8uDJqyOmknNePRvAvR6mP4xyuR+Bv
+m+Lgp0DMWTw5J9CKpydZDItc49T/mJ5tPhdFVd+am0NAQnmr1MCZ6nHxAoGABS3Y
+LkaC9FdFUUqSU8+Chkd/YbOkuyiENdkvl6t2e52jo5DVc1T7mLiIrRQi4SI8N9bN
+/3oJWCT+uaSLX2ouCtNFunblzWHBrhxnZzTeqVq4SLc8aESAnbslKL4i8/+vYZlN
+s8xtiNcSvL+lMsOBORSXzpj/4Ot8WwTkn1qyGgECgYBKNTypzAHeLE6yVadFp3nQ
+Ckq9yzvP/ib05rvgbvrne00YeOxqJ9gtTrzgh7koqJyX1L4NwdkEza4ilDWpucn0
+xiUZS4SoaJq6ZvcBYS62Yr1t8n09iG47YL8ibgtmH3L+svaotvpVxVK+d7BLevA/
+ZboOWVe3icTy64BT3OQhmg==
-----END RSA TESTING KEY-----`))
func testingKey(s string) string { return strings.ReplaceAll(s, "TESTING KEY", "PRIVATE KEY") }
diff --git a/libgo/go/net/http/main_test.go b/libgo/go/net/http/main_test.go
index 6564627..27872b4 100644
--- a/libgo/go/net/http/main_test.go
+++ b/libgo/go/net/http/main_test.go
@@ -31,11 +31,8 @@ func interestingGoroutines() (gs []string) {
buf := make([]byte, 2<<20)
buf = buf[:runtime.Stack(buf, true)]
for _, g := range strings.Split(string(buf), "\n\n") {
- sl := strings.SplitN(g, "\n", 2)
- if len(sl) != 2 {
- continue
- }
- stack := strings.TrimSpace(sl[1])
+ _, stack, _ := strings.Cut(g, "\n")
+ stack = strings.TrimSpace(stack)
if stack == "" ||
strings.Contains(stack, "testing.(*M).before.func1") ||
strings.Contains(stack, "os/signal.signal_recv") ||
@@ -46,7 +43,7 @@ func interestingGoroutines() (gs []string) {
// These only show up with GOTRACEBACK=2; Issue 5005 (comment 28)
strings.Contains(stack, "runtime.goexit") ||
strings.Contains(stack, "created by runtime.gc") ||
- strings.Contains(stack, "net/http_test.interestingGoroutines") ||
+ strings.Contains(stack, "interestingGoroutines") ||
strings.Contains(stack, "runtime.MHeap_Scavenger") {
continue
}
diff --git a/libgo/go/net/http/omithttp2.go b/libgo/go/net/http/omithttp2.go
index 79599d0..3316f55 100644
--- a/libgo/go/net/http/omithttp2.go
+++ b/libgo/go/net/http/omithttp2.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build nethttpomithttp2
-// +build nethttpomithttp2
package http
@@ -27,7 +26,7 @@ const http2NextProtoTLS = "h2"
type http2Transport struct {
MaxHeaderListSize uint32
- ConnPool interface{}
+ ConnPool any
}
func (*http2Transport) RoundTrip(*Request) (*Response, error) { panic(noHTTP2) }
@@ -57,9 +56,9 @@ type http2Server struct {
NewWriteScheduler func() http2WriteScheduler
}
-type http2WriteScheduler interface{}
+type http2WriteScheduler any
-func http2NewPriorityWriteScheduler(interface{}) http2WriteScheduler { panic(noHTTP2) }
+func http2NewPriorityWriteScheduler(any) http2WriteScheduler { panic(noHTTP2) }
func http2ConfigureServer(s *Server, conf *http2Server) error { panic(noHTTP2) }
diff --git a/libgo/go/net/http/pprof/pprof.go b/libgo/go/net/http/pprof/pprof.go
index 888ea35..dc855c8 100644
--- a/libgo/go/net/http/pprof/pprof.go
+++ b/libgo/go/net/http/pprof/pprof.go
@@ -44,7 +44,7 @@
// The package also exports a handler that serves execution trace data
// for the "go tool trace" command. To collect a 5-second execution trace:
//
-// wget -O trace.out http://localhost:6060/debug/pprof/trace?seconds=5
+// curl -o trace.out http://localhost:6060/debug/pprof/trace?seconds=5
// go tool trace trace.out
//
// To view all available profiles, open http://localhost:6060/debug/pprof/
diff --git a/libgo/go/net/http/pprof/pprof_test.go b/libgo/go/net/http/pprof/pprof_test.go
index 84757e4..1a4d653 100644
--- a/libgo/go/net/http/pprof/pprof_test.go
+++ b/libgo/go/net/http/pprof/pprof_test.go
@@ -8,6 +8,7 @@ import (
"bytes"
"fmt"
"internal/profile"
+ "internal/testenv"
"io"
"net/http"
"net/http/httptest"
@@ -152,6 +153,10 @@ func mutexHog(duration time.Duration, hogger func(mu1, mu2 *sync.Mutex, start ti
}
func TestDeltaProfile(t *testing.T) {
+ if runtime.GOOS == "openbsd" && runtime.GOARCH == "arm" {
+ testenv.SkipFlaky(t, 50218)
+ }
+
rate := runtime.SetMutexProfileFraction(1)
defer func() {
runtime.SetMutexProfileFraction(rate)
diff --git a/libgo/go/net/http/request.go b/libgo/go/net/http/request.go
index 09cb0c7..76c2317 100644
--- a/libgo/go/net/http/request.go
+++ b/libgo/go/net/http/request.go
@@ -779,11 +779,10 @@ func removeZone(host string) string {
return host[:j] + host[i:]
}
-// ParseHTTPVersion parses an HTTP version string.
+// ParseHTTPVersion parses an HTTP version string according to RFC 7230, section 2.6.
// "HTTP/1.0" returns (1, 0, true). Note that strings without
// a minor version, such as "HTTP/2", are not valid.
func ParseHTTPVersion(vers string) (major, minor int, ok bool) {
- const Big = 1000000 // arbitrary upper bound
switch vers {
case "HTTP/1.1":
return 1, 1, true
@@ -793,19 +792,21 @@ func ParseHTTPVersion(vers string) (major, minor int, ok bool) {
if !strings.HasPrefix(vers, "HTTP/") {
return 0, 0, false
}
- dot := strings.Index(vers, ".")
- if dot < 0 {
+ if len(vers) != len("HTTP/X.Y") {
return 0, 0, false
}
- major, err := strconv.Atoi(vers[5:dot])
- if err != nil || major < 0 || major > Big {
+ if vers[6] != '.' {
return 0, 0, false
}
- minor, err = strconv.Atoi(vers[dot+1:])
- if err != nil || minor < 0 || minor > Big {
+ maj, err := strconv.ParseUint(vers[5:6], 10, 0)
+ if err != nil {
return 0, 0, false
}
- return major, minor, true
+ min, err := strconv.ParseUint(vers[7:8], 10, 0)
+ if err != nil {
+ return 0, 0, false
+ }
+ return int(maj), int(min), true
}
func validMethod(method string) bool {
@@ -939,7 +940,7 @@ func NewRequestWithContext(ctx context.Context, method, url string, body io.Read
func (r *Request) BasicAuth() (username, password string, ok bool) {
auth := r.Header.Get("Authorization")
if auth == "" {
- return
+ return "", "", false
}
return parseBasicAuth(auth)
}
@@ -950,18 +951,18 @@ func parseBasicAuth(auth string) (username, password string, ok bool) {
const prefix = "Basic "
// Case insensitive prefix match. See Issue 22736.
if len(auth) < len(prefix) || !ascii.EqualFold(auth[:len(prefix)], prefix) {
- return
+ return "", "", false
}
c, err := base64.StdEncoding.DecodeString(auth[len(prefix):])
if err != nil {
- return
+ return "", "", false
}
cs := string(c)
- s := strings.IndexByte(cs, ':')
- if s < 0 {
- return
+ username, password, ok = strings.Cut(cs, ":")
+ if !ok {
+ return "", "", false
}
- return cs[:s], cs[s+1:], true
+ return username, password, true
}
// SetBasicAuth sets the request's Authorization header to use HTTP
@@ -979,13 +980,12 @@ func (r *Request) SetBasicAuth(username, password string) {
// parseRequestLine parses "GET /foo HTTP/1.1" into its three parts.
func parseRequestLine(line string) (method, requestURI, proto string, ok bool) {
- s1 := strings.Index(line, " ")
- s2 := strings.Index(line[s1+1:], " ")
- if s1 < 0 || s2 < 0 {
- return
+ method, rest, ok1 := strings.Cut(line, " ")
+ requestURI, proto, ok2 := strings.Cut(rest, " ")
+ if !ok1 || !ok2 {
+ return "", "", "", false
}
- s2 += s1 + 1
- return line[:s1], line[s1+1 : s2], line[s2+1:], true
+ return method, requestURI, proto, true
}
var textprotoReaderPool sync.Pool
diff --git a/libgo/go/net/http/request_test.go b/libgo/go/net/http/request_test.go
index 4e0c4ba..4363e11 100644
--- a/libgo/go/net/http/request_test.go
+++ b/libgo/go/net/http/request_test.go
@@ -639,10 +639,10 @@ var parseHTTPVersionTests = []struct {
major, minor int
ok bool
}{
+ {"HTTP/0.0", 0, 0, true},
{"HTTP/0.9", 0, 9, true},
{"HTTP/1.0", 1, 0, true},
{"HTTP/1.1", 1, 1, true},
- {"HTTP/3.14", 3, 14, true},
{"HTTP", 0, 0, false},
{"HTTP/one.one", 0, 0, false},
@@ -651,6 +651,12 @@ var parseHTTPVersionTests = []struct {
{"HTTP/0,-1", 0, 0, false},
{"HTTP/", 0, 0, false},
{"HTTP/1,1", 0, 0, false},
+ {"HTTP/+1.1", 0, 0, false},
+ {"HTTP/1.+1", 0, 0, false},
+ {"HTTP/0000000001.1", 0, 0, false},
+ {"HTTP/1.0000000001", 0, 0, false},
+ {"HTTP/3.14", 0, 0, false},
+ {"HTTP/12.3", 0, 0, false},
}
func TestParseHTTPVersion(t *testing.T) {
diff --git a/libgo/go/net/http/requestwrite_test.go b/libgo/go/net/http/requestwrite_test.go
index 1157bdf..bdc1e3c 100644
--- a/libgo/go/net/http/requestwrite_test.go
+++ b/libgo/go/net/http/requestwrite_test.go
@@ -20,7 +20,7 @@ import (
type reqWriteTest struct {
Req Request
- Body interface{} // optional []byte or func() io.ReadCloser to populate Req.Body
+ Body any // optional []byte or func() io.ReadCloser to populate Req.Body
// Any of these three may be empty to skip that test.
WantWrite string // Request.Write
diff --git a/libgo/go/net/http/response.go b/libgo/go/net/http/response.go
index b8985da..297394e 100644
--- a/libgo/go/net/http/response.go
+++ b/libgo/go/net/http/response.go
@@ -165,16 +165,14 @@ func ReadResponse(r *bufio.Reader, req *Request) (*Response, error) {
}
return nil, err
}
- if i := strings.IndexByte(line, ' '); i == -1 {
+ proto, status, ok := strings.Cut(line, " ")
+ if !ok {
return nil, badStringError("malformed HTTP response", line)
- } else {
- resp.Proto = line[:i]
- resp.Status = strings.TrimLeft(line[i+1:], " ")
- }
- statusCode := resp.Status
- if i := strings.IndexByte(resp.Status, ' '); i != -1 {
- statusCode = resp.Status[:i]
}
+ resp.Proto = proto
+ resp.Status = strings.TrimLeft(status, " ")
+
+ statusCode, _, _ := strings.Cut(resp.Status, " ")
if len(statusCode) != 3 {
return nil, badStringError("malformed HTTP status code", statusCode)
}
@@ -182,7 +180,6 @@ func ReadResponse(r *bufio.Reader, req *Request) (*Response, error) {
if err != nil || resp.StatusCode < 0 {
return nil, badStringError("malformed HTTP status code", statusCode)
}
- var ok bool
if resp.ProtoMajor, resp.ProtoMinor, ok = ParseHTTPVersion(resp.Proto); !ok {
return nil, badStringError("malformed HTTP version", resp.Proto)
}
diff --git a/libgo/go/net/http/response_test.go b/libgo/go/net/http/response_test.go
index 8eef654..5a735b0 100644
--- a/libgo/go/net/http/response_test.go
+++ b/libgo/go/net/http/response_test.go
@@ -646,8 +646,8 @@ type readerAndCloser struct {
func TestReadResponseCloseInMiddle(t *testing.T) {
t.Parallel()
for _, test := range readResponseCloseInMiddleTests {
- fatalf := func(format string, args ...interface{}) {
- args = append([]interface{}{test.chunked, test.compressed}, args...)
+ fatalf := func(format string, args ...any) {
+ args = append([]any{test.chunked, test.compressed}, args...)
t.Fatalf("on test chunked=%v, compressed=%v: "+format, args...)
}
checkErr := func(err error, msg string) {
@@ -732,7 +732,7 @@ func TestReadResponseCloseInMiddle(t *testing.T) {
}
}
-func diff(t *testing.T, prefix string, have, want interface{}) {
+func diff(t *testing.T, prefix string, have, want any) {
t.Helper()
hv := reflect.ValueOf(have).Elem()
wv := reflect.ValueOf(want).Elem()
@@ -849,10 +849,10 @@ func TestReadResponseErrors(t *testing.T) {
type testCase struct {
name string // optional, defaults to in
in string
- wantErr interface{} // nil, err value, or string substring
+ wantErr any // nil, err value, or string substring
}
- status := func(s string, wantErr interface{}) testCase {
+ status := func(s string, wantErr any) testCase {
if wantErr == true {
wantErr = "malformed HTTP status code"
}
@@ -863,7 +863,7 @@ func TestReadResponseErrors(t *testing.T) {
}
}
- version := func(s string, wantErr interface{}) testCase {
+ version := func(s string, wantErr any) testCase {
if wantErr == true {
wantErr = "malformed HTTP version"
}
@@ -874,7 +874,7 @@ func TestReadResponseErrors(t *testing.T) {
}
}
- contentLength := func(status, body string, wantErr interface{}) testCase {
+ contentLength := func(status, body string, wantErr any) testCase {
return testCase{
name: fmt.Sprintf("status %q %q", status, body),
in: fmt.Sprintf("HTTP/1.1 %s\r\n%s", status, body),
@@ -947,7 +947,7 @@ func TestReadResponseErrors(t *testing.T) {
// wantErr can be nil, an error value to match exactly, or type string to
// match a substring.
-func matchErr(err error, wantErr interface{}) error {
+func matchErr(err error, wantErr any) error {
if err == nil {
if wantErr == nil {
return nil
diff --git a/libgo/go/net/http/roundtrip.go b/libgo/go/net/http/roundtrip.go
index eef7c79..c4c5d3b 100644
--- a/libgo/go/net/http/roundtrip.go
+++ b/libgo/go/net/http/roundtrip.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build !js || !wasm
-// +build !js !wasm
package http
diff --git a/libgo/go/net/http/roundtrip_js.go b/libgo/go/net/http/roundtrip_js.go
index 74c83a9..01c0600 100644
--- a/libgo/go/net/http/roundtrip_js.go
+++ b/libgo/go/net/http/roundtrip_js.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build js && wasm
-// +build js,wasm
package http
@@ -41,11 +40,19 @@ const jsFetchCreds = "js.fetch:credentials"
// Reference: https://developer.mozilla.org/en-US/docs/Web/API/WindowOrWorkerGlobalScope/fetch#Parameters
const jsFetchRedirect = "js.fetch:redirect"
-var useFakeNetwork = js.Global().Get("fetch").IsUndefined()
+// jsFetchMissing will be true if the Fetch API is not present in
+// the browser globals.
+var jsFetchMissing = js.Global().Get("fetch").IsUndefined()
// RoundTrip implements the RoundTripper interface using the WHATWG Fetch API.
func (t *Transport) RoundTrip(req *Request) (*Response, error) {
- if useFakeNetwork {
+ // The Transport has a documented contract that states that if the DialContext or
+ // DialTLSContext functions are set, they will be used to set up the connections.
+ // If they aren't set then the documented contract is to use Dial or DialTLS, even
+ // though they are deprecated. Therefore, if any of these are set, we should obey
+ // the contract and dial using the regular round-trip instead. Otherwise, we'll try
+ // to fall back on the Fetch API, unless it's not available.
+ if t.Dial != nil || t.DialContext != nil || t.DialTLS != nil || t.DialTLSContext != nil || jsFetchMissing {
return t.roundTrip(req)
}
@@ -111,7 +118,7 @@ func (t *Transport) RoundTrip(req *Request) (*Response, error) {
errCh = make(chan error, 1)
success, failure js.Func
)
- success = js.FuncOf(func(this js.Value, args []js.Value) interface{} {
+ success = js.FuncOf(func(this js.Value, args []js.Value) any {
success.Release()
failure.Release()
@@ -131,8 +138,24 @@ func (t *Transport) RoundTrip(req *Request) (*Response, error) {
}
contentLength := int64(0)
- if cl, err := strconv.ParseInt(header.Get("Content-Length"), 10, 64); err == nil {
+ clHeader := header.Get("Content-Length")
+ switch {
+ case clHeader != "":
+ cl, err := strconv.ParseInt(clHeader, 10, 64)
+ if err != nil {
+ errCh <- fmt.Errorf("net/http: ill-formed Content-Length header: %v", err)
+ return nil
+ }
+ if cl < 0 {
+ // Content-Length values less than 0 are invalid.
+ // See: https://datatracker.ietf.org/doc/html/rfc2616/#section-14.13
+ errCh <- fmt.Errorf("net/http: invalid Content-Length header: %q", clHeader)
+ return nil
+ }
contentLength = cl
+ default:
+ // If the response length is not declared, set it to -1.
+ contentLength = -1
}
b := result.Get("body")
@@ -159,7 +182,7 @@ func (t *Transport) RoundTrip(req *Request) (*Response, error) {
return nil
})
- failure = js.FuncOf(func(this js.Value, args []js.Value) interface{} {
+ failure = js.FuncOf(func(this js.Value, args []js.Value) any {
success.Release()
failure.Release()
errCh <- fmt.Errorf("net/http: fetch() failed: %s", args[0].Get("message").String())
@@ -200,7 +223,7 @@ func (r *streamReader) Read(p []byte) (n int, err error) {
bCh = make(chan []byte, 1)
errCh = make(chan error, 1)
)
- success := js.FuncOf(func(this js.Value, args []js.Value) interface{} {
+ success := js.FuncOf(func(this js.Value, args []js.Value) any {
result := args[0]
if result.Get("done").Bool() {
errCh <- io.EOF
@@ -212,7 +235,7 @@ func (r *streamReader) Read(p []byte) (n int, err error) {
return nil
})
defer success.Release()
- failure := js.FuncOf(func(this js.Value, args []js.Value) interface{} {
+ failure := js.FuncOf(func(this js.Value, args []js.Value) any {
// Assumes it's a TypeError. See
// https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/TypeError
// for more information on this type. See
@@ -266,7 +289,7 @@ func (r *arrayReader) Read(p []byte) (n int, err error) {
bCh = make(chan []byte, 1)
errCh = make(chan error, 1)
)
- success := js.FuncOf(func(this js.Value, args []js.Value) interface{} {
+ success := js.FuncOf(func(this js.Value, args []js.Value) any {
// Wrap the input ArrayBuffer with a Uint8Array
uint8arrayWrapper := uint8Array.New(args[0])
value := make([]byte, uint8arrayWrapper.Get("byteLength").Int())
@@ -275,7 +298,7 @@ func (r *arrayReader) Read(p []byte) (n int, err error) {
return nil
})
defer success.Release()
- failure := js.FuncOf(func(this js.Value, args []js.Value) interface{} {
+ failure := js.FuncOf(func(this js.Value, args []js.Value) any {
// Assumes it's a TypeError. See
// https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/TypeError
// for more information on this type.
diff --git a/libgo/go/net/http/serve_test.go b/libgo/go/net/http/serve_test.go
index 6394da3..fb18cb2 100644
--- a/libgo/go/net/http/serve_test.go
+++ b/libgo/go/net/http/serve_test.go
@@ -23,6 +23,7 @@ import (
"net"
. "net/http"
"net/http/httptest"
+ "net/http/httptrace"
"net/http/httputil"
"net/http/internal"
"net/http/internal/testcert"
@@ -2146,7 +2147,7 @@ func TestInvalidTrailerClosesConnection(t *testing.T) {
// Read and Write.
type slowTestConn struct {
// over multiple calls to Read, time.Durations are slept, strings are read.
- script []interface{}
+ script []any
closec chan bool
mu sync.Mutex // guards rd/wd
@@ -2238,7 +2239,7 @@ func TestRequestBodyTimeoutClosesConnection(t *testing.T) {
defer afterTest(t)
for _, handler := range testHandlerBodyConsumers {
conn := &slowTestConn{
- script: []interface{}{
+ script: []any{
"POST /public HTTP/1.1\r\n" +
"Host: test\r\n" +
"Content-Length: 10000\r\n" +
@@ -2273,6 +2274,18 @@ func TestRequestBodyTimeoutClosesConnection(t *testing.T) {
}
}
+// cancelableTimeoutContext overwrites the error message to DeadlineExceeded
+type cancelableTimeoutContext struct {
+ context.Context
+}
+
+func (c cancelableTimeoutContext) Err() error {
+ if c.Context.Err() != nil {
+ return context.DeadlineExceeded
+ }
+ return nil
+}
+
func TestTimeoutHandler_h1(t *testing.T) { testTimeoutHandler(t, h1Mode) }
func TestTimeoutHandler_h2(t *testing.T) { testTimeoutHandler(t, h2Mode) }
func testTimeoutHandler(t *testing.T, h2 bool) {
@@ -2285,8 +2298,9 @@ func testTimeoutHandler(t *testing.T, h2 bool) {
_, werr := w.Write([]byte("hi"))
writeErrors <- werr
})
- timeout := make(chan time.Time, 1) // write to this to force timeouts
- cst := newClientServerTest(t, h2, NewTestTimeoutHandler(sayHi, timeout))
+ ctx, cancel := context.WithCancel(context.Background())
+ h := NewTestTimeoutHandler(sayHi, cancelableTimeoutContext{ctx})
+ cst := newClientServerTest(t, h2, h)
defer cst.close()
// Succeed without timing out:
@@ -2307,7 +2321,8 @@ func testTimeoutHandler(t *testing.T, h2 bool) {
}
// Times out:
- timeout <- time.Time{}
+ cancel()
+
res, err = cst.c.Get(cst.ts.URL)
if err != nil {
t.Error(err)
@@ -2428,8 +2443,9 @@ func TestTimeoutHandlerRaceHeaderTimeout(t *testing.T) {
_, werr := w.Write([]byte("hi"))
writeErrors <- werr
})
- timeout := make(chan time.Time, 1) // write to this to force timeouts
- cst := newClientServerTest(t, h1Mode, NewTestTimeoutHandler(sayHi, timeout))
+ ctx, cancel := context.WithCancel(context.Background())
+ h := NewTestTimeoutHandler(sayHi, cancelableTimeoutContext{ctx})
+ cst := newClientServerTest(t, h1Mode, h)
defer cst.close()
// Succeed without timing out:
@@ -2450,7 +2466,8 @@ func TestTimeoutHandlerRaceHeaderTimeout(t *testing.T) {
}
// Times out:
- timeout <- time.Time{}
+ cancel()
+
res, err = cst.c.Get(cst.ts.URL)
if err != nil {
t.Error(err)
@@ -2500,6 +2517,47 @@ func TestTimeoutHandlerStartTimerWhenServing(t *testing.T) {
}
}
+func TestTimeoutHandlerContextCanceled(t *testing.T) {
+ setParallel(t)
+ defer afterTest(t)
+ writeErrors := make(chan error, 1)
+ sayHi := HandlerFunc(func(w ResponseWriter, r *Request) {
+ w.Header().Set("Content-Type", "text/plain")
+ var err error
+ // The request context has already been canceled, but
+ // retry the write for a while to give the timeout handler
+ // a chance to notice.
+ for i := 0; i < 100; i++ {
+ _, err = w.Write([]byte("a"))
+ if err != nil {
+ break
+ }
+ time.Sleep(1 * time.Millisecond)
+ }
+ writeErrors <- err
+ })
+ ctx, cancel := context.WithCancel(context.Background())
+ cancel()
+ h := NewTestTimeoutHandler(sayHi, ctx)
+ cst := newClientServerTest(t, h1Mode, h)
+ defer cst.close()
+
+ res, err := cst.c.Get(cst.ts.URL)
+ if err != nil {
+ t.Error(err)
+ }
+ if g, e := res.StatusCode, StatusServiceUnavailable; g != e {
+ t.Errorf("got res.StatusCode %d; expected %d", g, e)
+ }
+ body, _ := io.ReadAll(res.Body)
+ if g, e := string(body), ""; g != e {
+ t.Errorf("got body %q; expected %q", g, e)
+ }
+ if g, e := <-writeErrors, context.Canceled; g != e {
+ t.Errorf("got unexpected Write in handler: %v, want %g", g, e)
+ }
+}
+
// https://golang.org/issue/15948
func TestTimeoutHandlerEmptyResponse(t *testing.T) {
setParallel(t)
@@ -2708,7 +2766,7 @@ func TestHandlerPanicWithHijack(t *testing.T) {
testHandlerPanic(t, true, h1Mode, nil, "intentional death for testing")
}
-func testHandlerPanic(t *testing.T, withHijack, h2 bool, wrapper func(Handler) Handler, panicValue interface{}) {
+func testHandlerPanic(t *testing.T, withHijack, h2 bool, wrapper func(Handler) Handler, panicValue any) {
defer afterTest(t)
// Unlike the other tests that set the log output to io.Discard
// to quiet the output, this test uses a pipe. The pipe serves three
@@ -3017,22 +3075,14 @@ func TestClientWriteShutdown(t *testing.T) {
if err != nil {
t.Fatalf("CloseWrite: %v", err)
}
- donec := make(chan bool)
- go func() {
- defer close(donec)
- bs, err := io.ReadAll(conn)
- if err != nil {
- t.Errorf("ReadAll: %v", err)
- }
- got := string(bs)
- if got != "" {
- t.Errorf("read %q from server; want nothing", got)
- }
- }()
- select {
- case <-donec:
- case <-time.After(10 * time.Second):
- t.Fatalf("timeout")
+
+ bs, err := io.ReadAll(conn)
+ if err != nil {
+ t.Errorf("ReadAll: %v", err)
+ }
+ got := string(bs)
+ if got != "" {
+ t.Errorf("read %q from server; want nothing", got)
}
}
@@ -3884,7 +3934,7 @@ func testTransportAndServerSharedBodyRace(t *testing.T, h2 bool) {
// this test fails, it hangs. This helps debugging and I've
// added this enough times "temporarily". It now gets added
// full time.
- errorf := func(format string, args ...interface{}) {
+ errorf := func(format string, args ...any) {
v := fmt.Sprintf(format, args...)
println(v)
t.Error(v)
@@ -3893,10 +3943,10 @@ func testTransportAndServerSharedBodyRace(t *testing.T, h2 bool) {
unblockBackend := make(chan bool)
backend := newClientServerTest(t, h2, HandlerFunc(func(rw ResponseWriter, req *Request) {
gone := rw.(CloseNotifier).CloseNotify()
- didCopy := make(chan interface{})
+ didCopy := make(chan any)
go func() {
n, err := io.CopyN(rw, req.Body, bodySize)
- didCopy <- []interface{}{n, err}
+ didCopy <- []any{n, err}
}()
isGone := false
Loop:
@@ -4888,7 +4938,7 @@ func TestServerContext_LocalAddrContextKey_h2(t *testing.T) {
func testServerContext_LocalAddrContextKey(t *testing.T, h2 bool) {
setParallel(t)
defer afterTest(t)
- ch := make(chan interface{}, 1)
+ ch := make(chan any, 1)
cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
ch <- r.Context().Value(LocalAddrContextKey)
}))
@@ -5689,22 +5739,37 @@ func testServerKeepAlivesEnabled(t *testing.T, h2 bool) {
}
// Not parallel: messes with global variable. (http2goAwayTimeout)
defer afterTest(t)
- cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
- fmt.Fprintf(w, "%v", r.RemoteAddr)
- }))
+ cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {}))
defer cst.close()
srv := cst.ts.Config
srv.SetKeepAlivesEnabled(false)
- a := cst.getURL(cst.ts.URL)
- if !waitCondition(2*time.Second, 10*time.Millisecond, srv.ExportAllConnsIdle) {
- t.Fatalf("test server has active conns")
- }
- b := cst.getURL(cst.ts.URL)
- if a == b {
- t.Errorf("got same connection between first and second requests")
- }
- if !waitCondition(2*time.Second, 10*time.Millisecond, srv.ExportAllConnsIdle) {
- t.Fatalf("test server has active conns")
+ for try := 0; try < 2; try++ {
+ if !waitCondition(2*time.Second, 10*time.Millisecond, srv.ExportAllConnsIdle) {
+ t.Fatalf("request %v: test server has active conns", try)
+ }
+ conns := 0
+ var info httptrace.GotConnInfo
+ ctx := httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{
+ GotConn: func(v httptrace.GotConnInfo) {
+ conns++
+ info = v
+ },
+ })
+ req, err := NewRequestWithContext(ctx, "GET", cst.ts.URL, nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ res, err := cst.c.Do(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+ res.Body.Close()
+ if conns != 1 {
+ t.Fatalf("request %v: got %v conns, want 1", try, conns)
+ }
+ if info.Reused || info.WasIdle {
+ t.Fatalf("request %v: Reused=%v (want false), WasIdle=%v (want false)", try, info.Reused, info.WasIdle)
+ }
}
}
@@ -5933,11 +5998,7 @@ func TestServerHijackGetsBackgroundByte_big(t *testing.T) {
t.Fatal(err)
}
- select {
- case <-done:
- case <-time.After(2 * time.Second):
- t.Error("timeout")
- }
+ <-done
}
// Issue 18319: test that the Server validates the request method.
@@ -6232,7 +6293,7 @@ func testContentEncodingNoSniffing(t *testing.T, h2 bool) {
// setting contentEncoding as an interface instead of a string
// directly, so as to differentiate between 3 states:
// unset, empty string "" and set string "foo/bar".
- contentEncoding interface{}
+ contentEncoding any
wantContentType string
}
@@ -6490,7 +6551,7 @@ func TestDisableKeepAliveUpgrade(t *testing.T) {
rwc, ok := resp.Body.(io.ReadWriteCloser)
if !ok {
- t.Fatalf("Response.Body is not a io.ReadWriteCloser: %T", resp.Body)
+ t.Fatalf("Response.Body is not an io.ReadWriteCloser: %T", resp.Body)
}
_, err = rwc.Write([]byte("hello"))
@@ -6609,3 +6670,63 @@ func testQuerySemicolon(t *testing.T, query string, wantX string, allowSemicolon
}
}
}
+
+func TestMaxBytesHandler(t *testing.T) {
+ setParallel(t)
+ defer afterTest(t)
+
+ for _, maxSize := range []int64{100, 1_000, 1_000_000} {
+ for _, requestSize := range []int64{100, 1_000, 1_000_000} {
+ t.Run(fmt.Sprintf("max size %d request size %d", maxSize, requestSize),
+ func(t *testing.T) {
+ testMaxBytesHandler(t, maxSize, requestSize)
+ })
+ }
+ }
+}
+
+func testMaxBytesHandler(t *testing.T, maxSize, requestSize int64) {
+ var (
+ handlerN int64
+ handlerErr error
+ )
+ echo := HandlerFunc(func(w ResponseWriter, r *Request) {
+ var buf bytes.Buffer
+ handlerN, handlerErr = io.Copy(&buf, r.Body)
+ io.Copy(w, &buf)
+ })
+
+ ts := httptest.NewServer(MaxBytesHandler(echo, maxSize))
+ defer ts.Close()
+
+ c := ts.Client()
+ var buf strings.Builder
+ body := strings.NewReader(strings.Repeat("a", int(requestSize)))
+ res, err := c.Post(ts.URL, "text/plain", body)
+ if err != nil {
+ t.Errorf("unexpected connection error: %v", err)
+ } else {
+ _, err = io.Copy(&buf, res.Body)
+ res.Body.Close()
+ if err != nil {
+ t.Errorf("unexpected read error: %v", err)
+ }
+ }
+ if handlerN > maxSize {
+ t.Errorf("expected max request body %d; got %d", maxSize, handlerN)
+ }
+ if requestSize > maxSize && handlerErr == nil {
+ t.Error("expected error on handler side; got nil")
+ }
+ if requestSize <= maxSize {
+ if handlerErr != nil {
+ t.Errorf("%d expected nil error on handler side; got %v", requestSize, handlerErr)
+ }
+ if handlerN != requestSize {
+ t.Errorf("expected request of size %d; got %d", requestSize, handlerN)
+ }
+ }
+ if buf.Len() != int(handlerN) {
+ t.Errorf("expected echo of size %d; got %d", handlerN, buf.Len())
+ }
+}
diff --git a/libgo/go/net/http/server.go b/libgo/go/net/http/server.go
index ce39933..f5cdc3a 100644
--- a/libgo/go/net/http/server.go
+++ b/libgo/go/net/http/server.go
@@ -13,6 +13,7 @@ import (
"crypto/tls"
"errors"
"fmt"
+ "internal/godebug"
"io"
"log"
"math/rand"
@@ -20,7 +21,6 @@ import (
"net/textproto"
"net/url"
urlpkg "net/url"
- "os"
"path"
"runtime"
"sort"
@@ -494,8 +494,8 @@ type response struct {
// prior to the headers being written. If the set of trailers is fixed
// or known before the header is written, the normal Go trailers mechanism
// is preferred:
-// https://golang.org/pkg/net/http/#ResponseWriter
-// https://golang.org/pkg/net/http/#example_ResponseWriter_trailers
+// https://pkg.go.dev/net/http#ResponseWriter
+// https://pkg.go.dev/net/http#example-ResponseWriter-Trailers
const TrailerPrefix = "Trailer:"
// finalTrailers is called after the Handler exits and returns a non-nil
@@ -798,7 +798,7 @@ var (
)
var copyBufPool = sync.Pool{
- New: func() interface{} {
+ New: func() any {
b := make([]byte, 32*1024)
return &b
},
@@ -865,6 +865,28 @@ func (srv *Server) initialReadLimitSize() int64 {
return int64(srv.maxHeaderBytes()) + 4096 // bufio slop
}
+// tlsHandshakeTimeout returns the time limit permitted for the TLS
+// handshake, or zero for unlimited.
+//
+// It returns the minimum of any positive ReadHeaderTimeout,
+// ReadTimeout, or WriteTimeout.
+func (srv *Server) tlsHandshakeTimeout() time.Duration {
+ var ret time.Duration
+ for _, v := range [...]time.Duration{
+ srv.ReadHeaderTimeout,
+ srv.ReadTimeout,
+ srv.WriteTimeout,
+ } {
+ if v <= 0 {
+ continue
+ }
+ if ret == 0 || v < ret {
+ ret = v
+ }
+ }
+ return ret
+}
+
// wrapper around io.ReadCloser which on first read, sends an
// HTTP/1.1 100 Continue header
type expectContinueReader struct {
@@ -1409,11 +1431,11 @@ func (cw *chunkWriter) writeHeader(p []byte) {
hasCL = false
}
- if w.req.Method == "HEAD" || !bodyAllowedForStatus(code) {
- // do nothing
- } else if code == StatusNoContent {
+ if w.req.Method == "HEAD" || !bodyAllowedForStatus(code) || code == StatusNoContent {
+ // Response has no body.
delHeader("Transfer-Encoding")
} else if hasCL {
+ // Content-Length has been provided, so no chunking is to be done.
delHeader("Transfer-Encoding")
} else if w.req.ProtoAtLeast(1, 1) {
// HTTP/1.1 or greater: Transfer-Encoding has been set to identity, and no
@@ -1424,6 +1446,7 @@ func (cw *chunkWriter) writeHeader(p []byte) {
if hasTE && te == "identity" {
cw.chunking = false
w.closeAfterReply = true
+ delHeader("Transfer-Encoding")
} else {
// HTTP/1.1 or greater: use chunked transfer encoding
// to avoid closing the connection at EOF.
@@ -1799,6 +1822,7 @@ func isCommonNetReadError(err error) bool {
func (c *conn) serve(ctx context.Context) {
c.remoteAddr = c.rwc.RemoteAddr().String()
ctx = context.WithValue(ctx, LocalAddrContextKey, c.rwc.LocalAddr())
+ var inFlightResponse *response
defer func() {
if err := recover(); err != nil && err != ErrAbortHandler {
const size = 64 << 10
@@ -1806,18 +1830,25 @@ func (c *conn) serve(ctx context.Context) {
buf = buf[:runtime.Stack(buf, false)]
c.server.logf("http: panic serving %v: %v\n%s", c.remoteAddr, err, buf)
}
+ if inFlightResponse != nil {
+ inFlightResponse.cancelCtx()
+ }
if !c.hijacked() {
+ if inFlightResponse != nil {
+ inFlightResponse.conn.r.abortPendingRead()
+ inFlightResponse.reqBody.Close()
+ }
c.close()
c.setState(c.rwc, StateClosed, runHooks)
}
}()
if tlsConn, ok := c.rwc.(*tls.Conn); ok {
- if d := c.server.ReadTimeout; d > 0 {
- c.rwc.SetReadDeadline(time.Now().Add(d))
- }
- if d := c.server.WriteTimeout; d > 0 {
- c.rwc.SetWriteDeadline(time.Now().Add(d))
+ tlsTO := c.server.tlsHandshakeTimeout()
+ if tlsTO > 0 {
+ dl := time.Now().Add(tlsTO)
+ c.rwc.SetReadDeadline(dl)
+ c.rwc.SetWriteDeadline(dl)
}
if err := tlsConn.HandshakeContext(ctx); err != nil {
// If the handshake failed due to the client not speaking
@@ -1831,6 +1862,11 @@ func (c *conn) serve(ctx context.Context) {
c.server.logf("http: TLS handshake error from %s: %v", c.rwc.RemoteAddr(), err)
return
}
+ // Restore Conn-level deadlines.
+ if tlsTO > 0 {
+ c.rwc.SetReadDeadline(time.Time{})
+ c.rwc.SetWriteDeadline(time.Time{})
+ }
c.tlsState = new(tls.ConnectionState)
*c.tlsState = tlsConn.ConnectionState()
if proto := c.tlsState.NegotiatedProtocol; validNextProto(proto) {
@@ -1931,7 +1967,9 @@ func (c *conn) serve(ctx context.Context) {
// in parallel even if their responses need to be serialized.
// But we're not going to implement HTTP pipelining because it
// was never deployed in the wild and the answer is HTTP/2.
+ inFlightResponse = w
serverHandler{c.server}.ServeHTTP(w, w.req)
+ inFlightResponse = nil
w.cancelCtx()
if c.hijacked() {
return
@@ -2277,7 +2315,7 @@ func cleanPath(p string) string {
// stripHostPort returns h without any trailing ":<port>".
func stripHostPort(h string) string {
// If no port on host, return unchanged
- if strings.IndexByte(h, ':') == -1 {
+ if !strings.Contains(h, ":") {
return h
}
host, _, err := net.SplitHostPort(h)
@@ -3157,7 +3195,7 @@ func (srv *Server) SetKeepAlivesEnabled(v bool) {
// TODO: Issue 26303: close HTTP/2 conns as soon as they become idle.
}
-func (s *Server) logf(format string, args ...interface{}) {
+func (s *Server) logf(format string, args ...any) {
if s.ErrorLog != nil {
s.ErrorLog.Printf(format, args...)
} else {
@@ -3168,7 +3206,7 @@ func (s *Server) logf(format string, args ...interface{}) {
// logf prints to the ErrorLog of the *Server associated with request r
// via ServerContextKey. If there's no associated server, or if ErrorLog
// is nil, logging is done via the log package's standard logger.
-func logf(r *Request, format string, args ...interface{}) {
+func logf(r *Request, format string, args ...any) {
s, _ := r.Context().Value(ServerContextKey).(*Server)
if s != nil && s.ErrorLog != nil {
s.ErrorLog.Printf(format, args...)
@@ -3264,7 +3302,7 @@ func (srv *Server) onceSetNextProtoDefaults_Serve() {
// configured otherwise. (by setting srv.TLSNextProto non-nil)
// It must only be called via srv.nextProtoOnce (use srv.setupHTTP2_*).
func (srv *Server) onceSetNextProtoDefaults() {
- if omitBundledHTTP2 || strings.Contains(os.Getenv("GODEBUG"), "http2server=0") {
+ if omitBundledHTTP2 || godebug.Get("http2server") == "0" {
return
}
// Enable HTTP/2 by default if the user hasn't otherwise
@@ -3331,7 +3369,7 @@ func (h *timeoutHandler) ServeHTTP(w ResponseWriter, r *Request) {
h: make(Header),
req: r,
}
- panicChan := make(chan interface{}, 1)
+ panicChan := make(chan any, 1)
go func() {
defer func() {
if p := recover(); p != nil {
@@ -3359,9 +3397,15 @@ func (h *timeoutHandler) ServeHTTP(w ResponseWriter, r *Request) {
case <-ctx.Done():
tw.mu.Lock()
defer tw.mu.Unlock()
- w.WriteHeader(StatusServiceUnavailable)
- io.WriteString(w, h.errorBody())
- tw.timedOut = true
+ switch err := ctx.Err(); err {
+ case context.DeadlineExceeded:
+ w.WriteHeader(StatusServiceUnavailable)
+ io.WriteString(w, h.errorBody())
+ tw.err = ErrHandlerTimeout
+ default:
+ w.WriteHeader(StatusServiceUnavailable)
+ tw.err = err
+ }
}
}
@@ -3372,7 +3416,7 @@ type timeoutWriter struct {
req *Request
mu sync.Mutex
- timedOut bool
+ err error
wroteHeader bool
code int
}
@@ -3392,8 +3436,8 @@ func (tw *timeoutWriter) Header() Header { return tw.h }
func (tw *timeoutWriter) Write(p []byte) (int, error) {
tw.mu.Lock()
defer tw.mu.Unlock()
- if tw.timedOut {
- return 0, ErrHandlerTimeout
+ if tw.err != nil {
+ return 0, tw.err
}
if !tw.wroteHeader {
tw.writeHeaderLocked(StatusOK)
@@ -3405,7 +3449,7 @@ func (tw *timeoutWriter) writeHeaderLocked(code int) {
checkWriteHeaderCode(code)
switch {
- case tw.timedOut:
+ case tw.err != nil:
return
case tw.wroteHeader:
if tw.req != nil {
@@ -3572,3 +3616,12 @@ func tlsRecordHeaderLooksLikeHTTP(hdr [5]byte) bool {
}
return false
}
+
+// MaxBytesHandler returns a Handler that runs h with its ResponseWriter and Request.Body wrapped by a MaxBytesReader.
+func MaxBytesHandler(h Handler, n int64) Handler {
+ return HandlerFunc(func(w ResponseWriter, r *Request) {
+ r2 := *r
+ r2.Body = MaxBytesReader(w, r.Body, n)
+ h.ServeHTTP(w, &r2)
+ })
+}
diff --git a/libgo/go/net/http/server_test.go b/libgo/go/net/http/server_test.go
index 0132f3b..d17c5c1 100644
--- a/libgo/go/net/http/server_test.go
+++ b/libgo/go/net/http/server_test.go
@@ -9,8 +9,61 @@ package http
import (
"fmt"
"testing"
+ "time"
)
+func TestServerTLSHandshakeTimeout(t *testing.T) {
+ tests := []struct {
+ s *Server
+ want time.Duration
+ }{
+ {
+ s: &Server{},
+ want: 0,
+ },
+ {
+ s: &Server{
+ ReadTimeout: -1,
+ },
+ want: 0,
+ },
+ {
+ s: &Server{
+ ReadTimeout: 5 * time.Second,
+ },
+ want: 5 * time.Second,
+ },
+ {
+ s: &Server{
+ ReadTimeout: 5 * time.Second,
+ WriteTimeout: -1,
+ },
+ want: 5 * time.Second,
+ },
+ {
+ s: &Server{
+ ReadTimeout: 5 * time.Second,
+ WriteTimeout: 4 * time.Second,
+ },
+ want: 4 * time.Second,
+ },
+ {
+ s: &Server{
+ ReadTimeout: 5 * time.Second,
+ ReadHeaderTimeout: 2 * time.Second,
+ WriteTimeout: 4 * time.Second,
+ },
+ want: 2 * time.Second,
+ },
+ }
+ for i, tt := range tests {
+ got := tt.s.tlsHandshakeTimeout()
+ if got != tt.want {
+ t.Errorf("%d. got %v; want %v", i, got, tt.want)
+ }
+ }
+}
+
func BenchmarkServerMatch(b *testing.B) {
fn := func(w ResponseWriter, r *Request) {
fmt.Fprintf(w, "OK")
diff --git a/libgo/go/net/http/transfer.go b/libgo/go/net/http/transfer.go
index 85c2e5a..6d51178 100644
--- a/libgo/go/net/http/transfer.go
+++ b/libgo/go/net/http/transfer.go
@@ -73,7 +73,7 @@ type transferWriter struct {
ByteReadCh chan readResult // non-nil if probeRequestBody called
}
-func newTransferWriter(r interface{}) (t *transferWriter, err error) {
+func newTransferWriter(r any) (t *transferWriter, err error) {
t = &transferWriter{}
// Extract relevant fields
@@ -212,6 +212,7 @@ func (t *transferWriter) probeRequestBody() {
rres.b = buf[0]
}
t.ByteReadCh <- rres
+ close(t.ByteReadCh)
}(t.Body)
timer := time.NewTimer(200 * time.Millisecond)
select {
@@ -480,7 +481,7 @@ func suppressedHeaders(status int) []string {
}
// msg is *Request or *Response.
-func readTransfer(msg interface{}, r *bufio.Reader) (err error) {
+func readTransfer(msg any, r *bufio.Reader) (err error) {
t := &transferReader{RequestMethod: "GET"}
// Unify input
@@ -808,7 +809,7 @@ func fixTrailer(header Header, chunked bool) (Header, error) {
// and then reads the trailer if necessary.
type body struct {
src io.Reader
- hdr interface{} // non-nil (Response or Request) value means read trailer
+ hdr any // non-nil (Response or Request) value means read trailer
r *bufio.Reader // underlying wire-format reader for the trailer
closing bool // is the connection to be closed after reading body?
doEarlyClose bool // whether Close should stop early
@@ -1029,7 +1030,7 @@ func (b *body) registerOnHitEOF(fn func()) {
b.onHitEOF = fn
}
-// bodyLocked is a io.Reader reading from a *body when its mutex is
+// bodyLocked is an io.Reader reading from a *body when its mutex is
// already held.
type bodyLocked struct {
b *body
@@ -1072,6 +1073,9 @@ func (fr finishAsyncByteRead) Read(p []byte) (n int, err error) {
if n == 1 {
p[0] = rres.b
}
+ if err == nil {
+ err = io.EOF
+ }
return
}
diff --git a/libgo/go/net/http/transport.go b/libgo/go/net/http/transport.go
index 309194e..5fe3e6e 100644
--- a/libgo/go/net/http/transport.go
+++ b/libgo/go/net/http/transport.go
@@ -17,6 +17,7 @@ import (
"crypto/tls"
"errors"
"fmt"
+ "internal/godebug"
"io"
"log"
"net"
@@ -24,7 +25,6 @@ import (
"net/http/internal/ascii"
"net/textproto"
"net/url"
- "os"
"reflect"
"strings"
"sync"
@@ -42,10 +42,10 @@ import (
// $no_proxy) environment variables.
var DefaultTransport RoundTripper = &Transport{
Proxy: ProxyFromEnvironment,
- DialContext: (&net.Dialer{
+ DialContext: defaultTransportDialContext(&net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
- }).DialContext,
+ }),
ForceAttemptHTTP2: true,
MaxIdleConns: 100,
IdleConnTimeout: 90 * time.Second,
@@ -360,7 +360,7 @@ func (t *Transport) hasCustomTLSDialer() bool {
// It must be called via t.nextProtoOnce.Do.
func (t *Transport) onceSetNextProtoDefaults() {
t.tlsNextProtoWasNil = (t.TLSNextProto == nil)
- if strings.Contains(os.Getenv("GODEBUG"), "http2client=0") {
+ if godebug.Get("http2client") == "0" {
return
}
@@ -1715,12 +1715,12 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (pconn *pers
return nil, err
}
if resp.StatusCode != 200 {
- f := strings.SplitN(resp.Status, " ", 2)
+ _, text, ok := strings.Cut(resp.Status, " ")
conn.Close()
- if len(f) < 2 {
+ if !ok {
return nil, errors.New("unknown status code")
}
- return nil, errors.New(f[1])
+ return nil, errors.New(text)
}
}
@@ -2481,7 +2481,7 @@ type requestAndChan struct {
callerGone <-chan struct{} // closed when roundTrip caller has returned
}
-// A writeRequest is sent by the readLoop's goroutine to the
+// A writeRequest is sent by the caller's goroutine to the
// writeLoop's goroutine to write a request while the read loop
// concurrently waits on both the write response and the server's
// reply.
@@ -2668,8 +2668,8 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err
// a t.Logf func. See export_test.go's Request.WithT method.
type tLogKey struct{}
-func (tr *transportRequest) logf(format string, args ...interface{}) {
- if logf, ok := tr.Request.Context().Value(tLogKey{}).(func(string, ...interface{})); ok {
+func (tr *transportRequest) logf(format string, args ...any) {
+ if logf, ok := tr.Request.Context().Value(tLogKey{}).(func(string, ...any)); ok {
logf(time.Now().Format(time.RFC3339Nano)+": "+format, args...)
}
}
diff --git a/libgo/go/net/http/transport_default_js.go b/libgo/go/net/http/transport_default_js.go
new file mode 100644
index 0000000..c07d35e
--- /dev/null
+++ b/libgo/go/net/http/transport_default_js.go
@@ -0,0 +1,17 @@
+// Copyright 2021 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build js && wasm
+// +build js,wasm
+
+package http
+
+import (
+ "context"
+ "net"
+)
+
+func defaultTransportDialContext(dialer *net.Dialer) func(context.Context, string, string) (net.Conn, error) {
+ return nil
+}
diff --git a/libgo/go/net/http/transport_default_other.go b/libgo/go/net/http/transport_default_other.go
new file mode 100644
index 0000000..8a2f1cc
--- /dev/null
+++ b/libgo/go/net/http/transport_default_other.go
@@ -0,0 +1,17 @@
+// Copyright 2021 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build !(js && wasm)
+// +build !js !wasm
+
+package http
+
+import (
+ "context"
+ "net"
+)
+
+func defaultTransportDialContext(dialer *net.Dialer) func(context.Context, string, string) (net.Conn, error) {
+ return dialer.DialContext
+}
diff --git a/libgo/go/net/http/transport_test.go b/libgo/go/net/http/transport_test.go
index 7e14749..fed092b 100644
--- a/libgo/go/net/http/transport_test.go
+++ b/libgo/go/net/http/transport_test.go
@@ -776,7 +776,7 @@ func TestTransportServerClosingUnexpectedly(t *testing.T) {
c := ts.Client()
fetch := func(n, retries int) string {
- condFatalf := func(format string, arg ...interface{}) {
+ condFatalf := func(format string, arg ...any) {
if retries <= 0 {
t.Fatalf(format, arg...)
}
@@ -3518,7 +3518,7 @@ func TestRetryRequestsOnError(t *testing.T) {
mu sync.Mutex
logbuf bytes.Buffer
)
- logf := func(format string, args ...interface{}) {
+ logf := func(format string, args ...any) {
mu.Lock()
defer mu.Unlock()
fmt.Fprintf(&logbuf, format, args...)
@@ -4495,7 +4495,7 @@ func testTransportEventTrace(t *testing.T, h2 bool, noHooks bool) {
var mu sync.Mutex // guards buf
var buf bytes.Buffer
- logf := func(format string, args ...interface{}) {
+ logf := func(format string, args ...any) {
mu.Lock()
defer mu.Unlock()
fmt.Fprintf(&buf, format, args...)
@@ -4654,7 +4654,7 @@ func testTransportEventTrace(t *testing.T, h2 bool, noHooks bool) {
func TestTransportEventTraceTLSVerify(t *testing.T) {
var mu sync.Mutex
var buf bytes.Buffer
- logf := func(format string, args ...interface{}) {
+ logf := func(format string, args ...any) {
mu.Lock()
defer mu.Unlock()
fmt.Fprintf(&buf, format, args...)
@@ -4740,7 +4740,7 @@ func TestTransportEventTraceRealDNS(t *testing.T) {
var mu sync.Mutex // guards buf
var buf bytes.Buffer
- logf := func(format string, args ...interface{}) {
+ logf := func(format string, args ...any) {
mu.Lock()
defer mu.Unlock()
fmt.Fprintf(&buf, format, args...)
@@ -6516,3 +6516,32 @@ func TestCancelRequestWhenSharingConnection(t *testing.T) {
close(r2c)
wg.Wait()
}
+
+func TestHandlerAbortRacesBodyRead(t *testing.T) {
+ setParallel(t)
+ defer afterTest(t)
+
+ ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) {
+ go io.Copy(io.Discard, req.Body)
+ panic(ErrAbortHandler)
+ }))
+ defer ts.Close()
+
+ var wg sync.WaitGroup
+ for i := 0; i < 2; i++ {
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ for j := 0; j < 10; j++ {
+ const reqLen = 6 * 1024 * 1024
+ req, _ := NewRequest("POST", ts.URL, &io.LimitedReader{R: neverEnding('x'), N: reqLen})
+ req.ContentLength = reqLen
+ resp, _ := ts.Client().Transport.RoundTrip(req)
+ if resp != nil {
+ resp.Body.Close()
+ }
+ }
+ }()
+ }
+ wg.Wait()
+}
diff --git a/libgo/go/net/http/triv.go b/libgo/go/net/http/triv.go
index 4dc6240..11b19ab 100644
--- a/libgo/go/net/http/triv.go
+++ b/libgo/go/net/http/triv.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build ignore
-// +build ignore
package main
diff --git a/libgo/go/net/interface_aix.go b/libgo/go/net/interface_aix.go
index bd55386..ba15a67 100644
--- a/libgo/go/net/interface_aix.go
+++ b/libgo/go/net/interface_aix.go
@@ -78,7 +78,7 @@ func interfaceTable(ifindex int) ([]Interface, error) {
// Retrieve MTU
ifr := &ifreq{}
copy(ifr.Name[:], ifi.Name)
- err = unix.Ioctl(sock, syscall.SIOCGIFMTU, uintptr(unsafe.Pointer(ifr)))
+ err = unix.Ioctl(sock, syscall.SIOCGIFMTU, unsafe.Pointer(ifr))
if err != nil {
return nil, err
}
diff --git a/libgo/go/net/interface_bsd.go b/libgo/go/net/interface_bsd.go
index 7578b1a..db7bc75 100644
--- a/libgo/go/net/interface_bsd.go
+++ b/libgo/go/net/interface_bsd.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build darwin || dragonfly || freebsd || netbsd || openbsd
-// +build darwin dragonfly freebsd netbsd openbsd
package net
diff --git a/libgo/go/net/interface_bsd_test.go b/libgo/go/net/interface_bsd_test.go
index 8d0d9c3..ce59962 100644
--- a/libgo/go/net/interface_bsd_test.go
+++ b/libgo/go/net/interface_bsd_test.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build darwin || dragonfly || freebsd || netbsd || openbsd
-// +build darwin dragonfly freebsd netbsd openbsd
package net
diff --git a/libgo/go/net/interface_bsdvar.go b/libgo/go/net/interface_bsdvar.go
index 6230e0b..e9bea3d 100644
--- a/libgo/go/net/interface_bsdvar.go
+++ b/libgo/go/net/interface_bsdvar.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build dragonfly || netbsd || openbsd
-// +build dragonfly netbsd openbsd
package net
diff --git a/libgo/go/net/interface_freebsd.go b/libgo/go/net/interface_freebsd.go
index 2b51fcb..8536bd3 100644
--- a/libgo/go/net/interface_freebsd.go
+++ b/libgo/go/net/interface_freebsd.go
@@ -11,16 +11,11 @@ import (
)
func interfaceMessages(ifindex int) ([]route.Message, error) {
- typ := route.RIBType(syscall.NET_RT_IFLISTL)
- rib, err := route.FetchRIB(syscall.AF_UNSPEC, typ, ifindex)
+ rib, err := route.FetchRIB(syscall.AF_UNSPEC, route.RIBTypeInterface, ifindex)
if err != nil {
- typ = route.RIBType(syscall.NET_RT_IFLIST)
- rib, err = route.FetchRIB(syscall.AF_UNSPEC, typ, ifindex)
- if err != nil {
- return nil, err
- }
+ return nil, err
}
- return route.ParseRIB(typ, rib)
+ return route.ParseRIB(route.RIBTypeInterface, rib)
}
// interfaceMulticastAddrTable returns addresses for a specific
diff --git a/libgo/go/net/interface_stub.go b/libgo/go/net/interface_stub.go
index 1075e36..fadd8b2 100644
--- a/libgo/go/net/interface_stub.go
+++ b/libgo/go/net/interface_stub.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build hurd || (js && wasm)
-// +build hurd js,wasm
package net
diff --git a/libgo/go/net/interface_test.go b/libgo/go/net/interface_test.go
index 754db36..f6c9868 100644
--- a/libgo/go/net/interface_test.go
+++ b/libgo/go/net/interface_test.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build !js
-// +build !js
package net
diff --git a/libgo/go/net/interface_unix_test.go b/libgo/go/net/interface_unix_test.go
index 0d69fa5..92ec13a 100644
--- a/libgo/go/net/interface_unix_test.go
+++ b/libgo/go/net/interface_unix_test.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build darwin || dragonfly || freebsd || linux || netbsd || openbsd
-// +build darwin dragonfly freebsd linux netbsd openbsd
package net
diff --git a/libgo/go/net/internal/socktest/main_test.go b/libgo/go/net/internal/socktest/main_test.go
index 8af85d3..c7c8d16 100644
--- a/libgo/go/net/internal/socktest/main_test.go
+++ b/libgo/go/net/internal/socktest/main_test.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build !js && !plan9
-// +build !js,!plan9
package socktest_test
diff --git a/libgo/go/net/internal/socktest/main_unix_test.go b/libgo/go/net/internal/socktest/main_unix_test.go
index 6aa8875..7d21f6f 100644
--- a/libgo/go/net/internal/socktest/main_unix_test.go
+++ b/libgo/go/net/internal/socktest/main_unix_test.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build !js && !plan9 && !windows
-// +build !js,!plan9,!windows
package socktest_test
diff --git a/libgo/go/net/internal/socktest/switch_posix.go b/libgo/go/net/internal/socktest/switch_posix.go
index cda74e8..fcad4ce 100644
--- a/libgo/go/net/internal/socktest/switch_posix.go
+++ b/libgo/go/net/internal/socktest/switch_posix.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build !plan9
-// +build !plan9
package socktest
diff --git a/libgo/go/net/internal/socktest/switch_stub.go b/libgo/go/net/internal/socktest/switch_stub.go
index 5aa2ece..8a2fc35 100644
--- a/libgo/go/net/internal/socktest/switch_stub.go
+++ b/libgo/go/net/internal/socktest/switch_stub.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build plan9
-// +build plan9
package socktest
diff --git a/libgo/go/net/internal/socktest/switch_unix.go b/libgo/go/net/internal/socktest/switch_unix.go
index be9ef6d..83df596 100644
--- a/libgo/go/net/internal/socktest/switch_unix.go
+++ b/libgo/go/net/internal/socktest/switch_unix.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build aix || darwin || dragonfly || freebsd || hurd || (js && wasm) || linux || netbsd || openbsd || solaris
-// +build aix darwin dragonfly freebsd hurd js,wasm linux netbsd openbsd solaris
package socktest
diff --git a/libgo/go/net/internal/socktest/sys_cloexec.go b/libgo/go/net/internal/socktest/sys_cloexec.go
index 5e95896..c2d9d4b 100644
--- a/libgo/go/net/internal/socktest/sys_cloexec.go
+++ b/libgo/go/net/internal/socktest/sys_cloexec.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build dragonfly || freebsd || hurd || illumos || linux || netbsd || openbsd
-// +build dragonfly freebsd hurd illumos linux netbsd openbsd
package socktest
diff --git a/libgo/go/net/internal/socktest/sys_unix.go b/libgo/go/net/internal/socktest/sys_unix.go
index 39f3dbc..0cb4693 100644
--- a/libgo/go/net/internal/socktest/sys_unix.go
+++ b/libgo/go/net/internal/socktest/sys_unix.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build aix || darwin || dragonfly || freebsd || hurd || (js && wasm) || linux || netbsd || openbsd || solaris
-// +build aix darwin dragonfly freebsd hurd js,wasm linux netbsd openbsd solaris
package socktest
diff --git a/libgo/go/net/ip.go b/libgo/go/net/ip.go
index 38e1aa2..54c5288 100644
--- a/libgo/go/net/ip.go
+++ b/libgo/go/net/ip.go
@@ -308,7 +308,7 @@ func ubtoa(dst []byte, start int, v byte) int {
// It returns one of 4 forms:
// - "<nil>", if ip has length 0
// - dotted decimal ("192.0.2.1"), if ip is an IPv4 or IP4-mapped IPv6 address
-// - IPv6 ("2001:db8::1"), if ip is a valid IPv6 address
+// - IPv6 conforming to RFC 5952 ("2001:db8::1"), if ip is a valid IPv6 address
// - the hexadecimal form of ip, without punctuation, if no other cases apply
func (ip IP) String() string {
p := ip
@@ -545,6 +545,9 @@ func (n *IPNet) Network() string { return "ip+net" }
// character and a mask expressed as hexadecimal form with no
// punctuation like "198.51.100.0/c000ff00".
func (n *IPNet) String() string {
+ if n == nil {
+ return "<nil>"
+ }
nn, m := networkNumberAndMask(n)
if nn == nil || m == nil {
return "<nil>"
diff --git a/libgo/go/net/ip_test.go b/libgo/go/net/ip_test.go
index 5bbda60..8f1590c 100644
--- a/libgo/go/net/ip_test.go
+++ b/libgo/go/net/ip_test.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build !js
-// +build !js
package net
@@ -408,6 +407,7 @@ var ipNetStringTests = []struct {
{&IPNet{IP: IPv4(192, 168, 1, 0), Mask: IPv4Mask(255, 0, 255, 0)}, "192.168.1.0/ff00ff00"},
{&IPNet{IP: ParseIP("2001:db8::"), Mask: CIDRMask(55, 128)}, "2001:db8::/55"},
{&IPNet{IP: ParseIP("2001:db8::"), Mask: IPMask(ParseIP("8000:f123:0:cafe::"))}, "2001:db8::/8000f1230000cafe0000000000000000"},
+ {nil, "<nil>"},
}
func TestIPNetString(t *testing.T) {
@@ -719,7 +719,7 @@ var ipAddrScopeTests = []struct {
{IP.IsPrivate, IP{0xfe, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, false},
}
-func name(f interface{}) string {
+func name(f any) string {
return runtime.FuncForPC(reflect.ValueOf(f).Pointer()).Name()
}
diff --git a/libgo/go/net/iprawsock_posix.go b/libgo/go/net/iprawsock_posix.go
index ffc437c..04f8e10 100644
--- a/libgo/go/net/iprawsock_posix.go
+++ b/libgo/go/net/iprawsock_posix.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build aix || darwin || dragonfly || freebsd || hurd || (js && wasm) || linux || netbsd || openbsd || solaris || windows
-// +build aix darwin dragonfly freebsd hurd js,wasm linux netbsd openbsd solaris windows
package net
diff --git a/libgo/go/net/iprawsock_test.go b/libgo/go/net/iprawsock_test.go
index a96448e..ca5ab48 100644
--- a/libgo/go/net/iprawsock_test.go
+++ b/libgo/go/net/iprawsock_test.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build !js
-// +build !js
package net
diff --git a/libgo/go/net/ipsock_posix.go b/libgo/go/net/ipsock_posix.go
index cdd191a..cec7eb7 100644
--- a/libgo/go/net/ipsock_posix.go
+++ b/libgo/go/net/ipsock_posix.go
@@ -3,13 +3,13 @@
// license that can be found in the LICENSE file.
//go:build aix || darwin || dragonfly || freebsd || hurd || (js && wasm) || linux || netbsd || openbsd || solaris || windows
-// +build aix darwin dragonfly freebsd hurd js,wasm linux netbsd openbsd solaris windows
package net
import (
"context"
"internal/poll"
+ "net/netip"
"runtime"
"syscall"
)
@@ -142,42 +142,87 @@ func internetSocket(ctx context.Context, net string, laddr, raddr sockaddr, soty
return socket(ctx, net, family, sotype, proto, ipv6only, laddr, raddr, ctrlFn)
}
+func ipToSockaddrInet4(ip IP, port int) (syscall.SockaddrInet4, error) {
+ if len(ip) == 0 {
+ ip = IPv4zero
+ }
+ ip4 := ip.To4()
+ if ip4 == nil {
+ return syscall.SockaddrInet4{}, &AddrError{Err: "non-IPv4 address", Addr: ip.String()}
+ }
+ sa := syscall.SockaddrInet4{Port: port}
+ copy(sa.Addr[:], ip4)
+ return sa, nil
+}
+
+func ipToSockaddrInet6(ip IP, port int, zone string) (syscall.SockaddrInet6, error) {
+ // In general, an IP wildcard address, which is either
+ // "0.0.0.0" or "::", means the entire IP addressing
+ // space. For some historical reason, it is used to
+ // specify "any available address" on some operations
+ // of IP node.
+ //
+ // When the IP node supports IPv4-mapped IPv6 address,
+ // we allow a listener to listen to the wildcard
+ // address of both IP addressing spaces by specifying
+ // IPv6 wildcard address.
+ if len(ip) == 0 || ip.Equal(IPv4zero) {
+ ip = IPv6zero
+ }
+ // We accept any IPv6 address including IPv4-mapped
+ // IPv6 address.
+ ip6 := ip.To16()
+ if ip6 == nil {
+ return syscall.SockaddrInet6{}, &AddrError{Err: "non-IPv6 address", Addr: ip.String()}
+ }
+ sa := syscall.SockaddrInet6{Port: port, ZoneId: uint32(zoneCache.index(zone))}
+ copy(sa.Addr[:], ip6)
+ return sa, nil
+}
+
func ipToSockaddr(family int, ip IP, port int, zone string) (syscall.Sockaddr, error) {
switch family {
case syscall.AF_INET:
- if len(ip) == 0 {
- ip = IPv4zero
- }
- ip4 := ip.To4()
- if ip4 == nil {
- return nil, &AddrError{Err: "non-IPv4 address", Addr: ip.String()}
+ sa, err := ipToSockaddrInet4(ip, port)
+ if err != nil {
+ return nil, err
}
- sa := &syscall.SockaddrInet4{Port: port}
- copy(sa.Addr[:], ip4)
- return sa, nil
+ return &sa, nil
case syscall.AF_INET6:
- // In general, an IP wildcard address, which is either
- // "0.0.0.0" or "::", means the entire IP addressing
- // space. For some historical reason, it is used to
- // specify "any available address" on some operations
- // of IP node.
- //
- // When the IP node supports IPv4-mapped IPv6 address,
- // we allow a listener to listen to the wildcard
- // address of both IP addressing spaces by specifying
- // IPv6 wildcard address.
- if len(ip) == 0 || ip.Equal(IPv4zero) {
- ip = IPv6zero
- }
- // We accept any IPv6 address including IPv4-mapped
- // IPv6 address.
- ip6 := ip.To16()
- if ip6 == nil {
- return nil, &AddrError{Err: "non-IPv6 address", Addr: ip.String()}
+ sa, err := ipToSockaddrInet6(ip, port, zone)
+ if err != nil {
+ return nil, err
}
- sa := &syscall.SockaddrInet6{Port: port, ZoneId: uint32(zoneCache.index(zone))}
- copy(sa.Addr[:], ip6)
- return sa, nil
+ return &sa, nil
}
return nil, &AddrError{Err: "invalid address family", Addr: ip.String()}
}
+
+func addrPortToSockaddrInet4(ap netip.AddrPort) (syscall.SockaddrInet4, error) {
+ // ipToSockaddrInet4 has special handling here for zero length slices.
+ // We do not, because netip has no concept of a generic zero IP address.
+ addr := ap.Addr()
+ if !addr.Is4() {
+ return syscall.SockaddrInet4{}, &AddrError{Err: "non-IPv4 address", Addr: addr.String()}
+ }
+ sa := syscall.SockaddrInet4{
+ Addr: addr.As4(),
+ Port: int(ap.Port()),
+ }
+ return sa, nil
+}
+
+func addrPortToSockaddrInet6(ap netip.AddrPort) (syscall.SockaddrInet6, error) {
+ // ipToSockaddrInet6 has special handling here for zero length slices.
+ // We do not, because netip has no concept of a generic zero IP address.
+ addr := ap.Addr()
+ if !addr.Is6() {
+ return syscall.SockaddrInet6{}, &AddrError{Err: "non-IPv6 address", Addr: addr.String()}
+ }
+ sa := syscall.SockaddrInet6{
+ Addr: addr.As16(),
+ Port: int(ap.Port()),
+ ZoneId: uint32(zoneCache.index(addr.Zone())),
+ }
+ return sa, nil
+}
diff --git a/libgo/go/net/listen_test.go b/libgo/go/net/listen_test.go
index b1dce29..59c0112 100644
--- a/libgo/go/net/listen_test.go
+++ b/libgo/go/net/listen_test.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build !js && !plan9
-// +build !js,!plan9
package net
@@ -380,7 +379,7 @@ func differentWildcardAddr(i, j string) bool {
return true
}
-func checkFirstListener(network string, ln interface{}) error {
+func checkFirstListener(network string, ln any) error {
switch network {
case "tcp":
fd := ln.(*TCPListener).fd
@@ -535,8 +534,6 @@ func TestIPv4MulticastListener(t *testing.T) {
switch runtime.GOOS {
case "android", "plan9":
t.Skipf("not supported on %s", runtime.GOOS)
- case "solaris", "illumos":
- t.Skipf("not supported on solaris or illumos, see golang.org/issue/7399")
}
if !supportsIPv4() {
t.Skip("IPv4 is not supported")
@@ -610,8 +607,6 @@ func TestIPv6MulticastListener(t *testing.T) {
switch runtime.GOOS {
case "plan9":
t.Skipf("not supported on %s", runtime.GOOS)
- case "solaris", "illumos":
- t.Skipf("not supported on solaris or illumos, see issue 7399")
}
if !supportsIPv6() {
t.Skip("IPv6 is not supported")
@@ -702,10 +697,7 @@ func multicastRIBContains(ip IP) (bool, error) {
// Issue 21856.
func TestClosingListener(t *testing.T) {
- ln, err := newLocalListener("tcp")
- if err != nil {
- t.Fatal(err)
- }
+ ln := newLocalListener(t, "tcp")
addr := ln.Addr()
go func() {
@@ -743,15 +735,13 @@ func TestListenConfigControl(t *testing.T) {
if !testableNetwork(network) {
continue
}
- ln, err := newLocalListener(network)
- if err != nil {
- t.Error(err)
- continue
- }
+ ln := newLocalListener(t, network)
address := ln.Addr().String()
+ // TODO: This is racy. The selected address could be reused in between
+ // this Close and the subsequent Listen.
ln.Close()
lc := ListenConfig{Control: controlOnConnSetup}
- ln, err = lc.Listen(context.Background(), network, address)
+ ln, err := lc.Listen(context.Background(), network, address)
if err != nil {
t.Error(err)
continue
@@ -764,18 +754,16 @@ func TestListenConfigControl(t *testing.T) {
if !testableNetwork(network) {
continue
}
- c, err := newLocalPacketListener(network)
- if err != nil {
- t.Error(err)
- continue
- }
+ c := newLocalPacketListener(t, network)
address := c.LocalAddr().String()
+ // TODO: This is racy. The selected address could be reused in between
+ // this Close and the subsequent ListenPacket.
c.Close()
if network == "unixgram" {
os.Remove(address)
}
lc := ListenConfig{Control: controlOnConnSetup}
- c, err = lc.ListenPacket(context.Background(), network, address)
+ c, err := lc.ListenPacket(context.Background(), network, address)
if err != nil {
t.Error(err)
continue
diff --git a/libgo/go/net/lookup.go b/libgo/go/net/lookup.go
index d350ef7..c7b8dc6 100644
--- a/libgo/go/net/lookup.go
+++ b/libgo/go/net/lookup.go
@@ -8,6 +8,7 @@ import (
"context"
"internal/nettrace"
"internal/singleflight"
+ "net/netip"
"sync"
)
@@ -232,6 +233,28 @@ func (r *Resolver) LookupIP(ctx context.Context, network, host string) ([]IP, er
return ips, nil
}
+// LookupNetIP looks up host using the local resolver.
+// It returns a slice of that host's IP addresses of the type specified by
+// network.
+// The network must be one of "ip", "ip4" or "ip6".
+func (r *Resolver) LookupNetIP(ctx context.Context, network, host string) ([]netip.Addr, error) {
+ // TODO(bradfitz): make this efficient, making the internal net package
+ // type throughout be netip.Addr and only converting to the net.IP slice
+ // version at the edge. But for now (2021-10-20), this is a wrapper around
+ // the old way.
+ ips, err := r.LookupIP(ctx, network, host)
+ if err != nil {
+ return nil, err
+ }
+ ret := make([]netip.Addr, 0, len(ips))
+ for _, ip := range ips {
+ if a, ok := netip.AddrFromSlice(ip); ok {
+ ret = append(ret, a)
+ }
+ }
+ return ret, nil
+}
+
// onlyValuesCtx is a context that uses an underlying context
// for value lookup if the underlying context hasn't yet expired.
type onlyValuesCtx struct {
@@ -242,7 +265,7 @@ type onlyValuesCtx struct {
var _ context.Context = (*onlyValuesCtx)(nil)
// Value performs a lookup if the original context hasn't expired.
-func (ovc *onlyValuesCtx) Value(key interface{}) interface{} {
+func (ovc *onlyValuesCtx) Value(key any) any {
select {
case <-ovc.lookupValues.Done():
return nil
@@ -291,7 +314,7 @@ func (r *Resolver) lookupIPAddr(ctx context.Context, network, host string) ([]IP
lookupKey := network + "\000" + host
dnsWaitGroup.Add(1)
- ch, called := r.getLookupGroup().DoChan(lookupKey, func() (interface{}, error) {
+ ch, called := r.getLookupGroup().DoChan(lookupKey, func() (any, error) {
defer dnsWaitGroup.Done()
return testHookLookupIP(lookupGroupCtx, resolverFunc, network, host)
})
@@ -316,24 +339,45 @@ func (r *Resolver) lookupIPAddr(ctx context.Context, network, host string) ([]IP
lookupGroupCancel()
}()
}
- err := mapErr(ctx.Err())
+ ctxErr := ctx.Err()
+ err := &DNSError{
+ Err: mapErr(ctxErr).Error(),
+ Name: host,
+ IsTimeout: ctxErr == context.DeadlineExceeded,
+ }
if trace != nil && trace.DNSDone != nil {
trace.DNSDone(nil, false, err)
}
return nil, err
case r := <-ch:
lookupGroupCancel()
+ err := r.Err
+ if err != nil {
+ if _, ok := err.(*DNSError); !ok {
+ isTimeout := false
+ if err == context.DeadlineExceeded {
+ isTimeout = true
+ } else if terr, ok := err.(timeout); ok {
+ isTimeout = terr.Timeout()
+ }
+ err = &DNSError{
+ Err: err.Error(),
+ Name: host,
+ IsTimeout: isTimeout,
+ }
+ }
+ }
if trace != nil && trace.DNSDone != nil {
addrs, _ := r.Val.([]IPAddr)
- trace.DNSDone(ipAddrsEface(addrs), r.Shared, r.Err)
+ trace.DNSDone(ipAddrsEface(addrs), r.Shared, err)
}
- return lookupIPReturn(r.Val, r.Err, r.Shared)
+ return lookupIPReturn(r.Val, err, r.Shared)
}
}
// lookupIPReturn turns the return values from singleflight.Do into
// the return values from LookupIP.
-func lookupIPReturn(addrsi interface{}, err error, shared bool) ([]IPAddr, error) {
+func lookupIPReturn(addrsi any, err error, shared bool) ([]IPAddr, error) {
if err != nil {
return nil, err
}
@@ -347,8 +391,8 @@ func lookupIPReturn(addrsi interface{}, err error, shared bool) ([]IPAddr, error
}
// ipAddrsEface returns an empty interface slice of addrs.
-func ipAddrsEface(addrs []IPAddr) []interface{} {
- s := make([]interface{}, len(addrs))
+func ipAddrsEface(addrs []IPAddr) []any {
+ s := make([]any, len(addrs))
for i, v := range addrs {
s[i] = v
}
@@ -442,7 +486,7 @@ func (r *Resolver) LookupCNAME(ctx context.Context, host string) (string, error)
// The returned service names are validated to be properly
// formatted presentation-format domain names. If the response contains
// invalid names, those records are filtered out and an error
-// will be returned alongside the the remaining results, if any.
+// will be returned alongside the remaining results, if any.
func LookupSRV(service, proto, name string) (cname string, addrs []*SRV, err error) {
return DefaultResolver.LookupSRV(context.Background(), service, proto, name)
}
@@ -460,7 +504,7 @@ func LookupSRV(service, proto, name string) (cname string, addrs []*SRV, err err
// The returned service names are validated to be properly
// formatted presentation-format domain names. If the response contains
// invalid names, those records are filtered out and an error
-// will be returned alongside the the remaining results, if any.
+// will be returned alongside the remaining results, if any.
func (r *Resolver) LookupSRV(ctx context.Context, service, proto, name string) (string, []*SRV, error) {
cname, addrs, err := r.lookupSRV(ctx, service, proto, name)
if err != nil {
@@ -490,7 +534,7 @@ func (r *Resolver) LookupSRV(ctx context.Context, service, proto, name string) (
// The returned mail server names are validated to be properly
// formatted presentation-format domain names. If the response contains
// invalid names, those records are filtered out and an error
-// will be returned alongside the the remaining results, if any.
+// will be returned alongside the remaining results, if any.
//
// LookupMX uses context.Background internally; to specify the context, use
// Resolver.LookupMX.
@@ -503,7 +547,7 @@ func LookupMX(name string) ([]*MX, error) {
// The returned mail server names are validated to be properly
// formatted presentation-format domain names. If the response contains
// invalid names, those records are filtered out and an error
-// will be returned alongside the the remaining results, if any.
+// will be returned alongside the remaining results, if any.
func (r *Resolver) LookupMX(ctx context.Context, name string) ([]*MX, error) {
records, err := r.lookupMX(ctx, name)
if err != nil {
@@ -514,9 +558,7 @@ func (r *Resolver) LookupMX(ctx context.Context, name string) ([]*MX, error) {
if mx == nil {
continue
}
- // Bypass the hostname validity check for targets which contain only a dot,
- // as this is used to represent a 'Null' MX record.
- if mx.Host != "." && !isDomainName(mx.Host) {
+ if !isDomainName(mx.Host) {
continue
}
filteredMX = append(filteredMX, mx)
@@ -532,7 +574,7 @@ func (r *Resolver) LookupMX(ctx context.Context, name string) ([]*MX, error) {
// The returned name server names are validated to be properly
// formatted presentation-format domain names. If the response contains
// invalid names, those records are filtered out and an error
-// will be returned alongside the the remaining results, if any.
+// will be returned alongside the remaining results, if any.
//
// LookupNS uses context.Background internally; to specify the context, use
// Resolver.LookupNS.
@@ -545,7 +587,7 @@ func LookupNS(name string) ([]*NS, error) {
// The returned name server names are validated to be properly
// formatted presentation-format domain names. If the response contains
// invalid names, those records are filtered out and an error
-// will be returned alongside the the remaining results, if any.
+// will be returned alongside the remaining results, if any.
func (r *Resolver) LookupNS(ctx context.Context, name string) ([]*NS, error) {
records, err := r.lookupNS(ctx, name)
if err != nil {
@@ -585,7 +627,7 @@ func (r *Resolver) LookupTXT(ctx context.Context, name string) ([]string, error)
//
// The returned names are validated to be properly formatted presentation-format
// domain names. If the response contains invalid names, those records are filtered
-// out and an error will be returned alongside the the remaining results, if any.
+// out and an error will be returned alongside the remaining results, if any.
//
// When using the host C library resolver, at most one result will be
// returned. To bypass the host resolver, use a custom Resolver.
@@ -601,7 +643,7 @@ func LookupAddr(addr string) (names []string, err error) {
//
// The returned names are validated to be properly formatted presentation-format
// domain names. If the response contains invalid names, those records are filtered
-// out and an error will be returned alongside the the remaining results, if any.
+// out and an error will be returned alongside the remaining results, if any.
func (r *Resolver) LookupAddr(ctx context.Context, addr string) ([]string, error) {
names, err := r.lookupAddr(ctx, addr)
if err != nil {
@@ -620,6 +662,6 @@ func (r *Resolver) LookupAddr(ctx context.Context, addr string) ([]string, error
}
// errMalformedDNSRecordsDetail is the DNSError detail which is returned when a Resolver.Lookup...
-// method recieves DNS records which contain invalid DNS names. This may be returned alongside
+// method receives DNS records which contain invalid DNS names. This may be returned alongside
// results which have had the malformed records filtered out.
var errMalformedDNSRecordsDetail = "DNS response contained records which contain invalid names"
diff --git a/libgo/go/net/lookup_fake.go b/libgo/go/net/lookup_fake.go
index f4fcaed..c27eae4 100644
--- a/libgo/go/net/lookup_fake.go
+++ b/libgo/go/net/lookup_fake.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build js && wasm
-// +build js,wasm
package net
diff --git a/libgo/go/net/lookup_plan9.go b/libgo/go/net/lookup_plan9.go
index 75c18b3..d43a03b 100644
--- a/libgo/go/net/lookup_plan9.go
+++ b/libgo/go/net/lookup_plan9.go
@@ -262,8 +262,8 @@ func (*Resolver) lookupSRV(ctx context.Context, service, proto, name string) (cn
if !(portOk && priorityOk && weightOk) {
continue
}
- addrs = append(addrs, &SRV{absDomainName([]byte(f[5])), uint16(port), uint16(priority), uint16(weight)})
- cname = absDomainName([]byte(f[0]))
+ addrs = append(addrs, &SRV{absDomainName(f[5]), uint16(port), uint16(priority), uint16(weight)})
+ cname = absDomainName(f[0])
}
byPriorityWeight(addrs).sort()
return
@@ -280,7 +280,7 @@ func (*Resolver) lookupMX(ctx context.Context, name string) (mx []*MX, err error
continue
}
if pref, _, ok := dtoi(f[2]); ok {
- mx = append(mx, &MX{absDomainName([]byte(f[3])), uint16(pref)})
+ mx = append(mx, &MX{absDomainName(f[3]), uint16(pref)})
}
}
byPref(mx).sort()
@@ -297,7 +297,7 @@ func (*Resolver) lookupNS(ctx context.Context, name string) (ns []*NS, err error
if len(f) < 3 {
continue
}
- ns = append(ns, &NS{absDomainName([]byte(f[2]))})
+ ns = append(ns, &NS{absDomainName(f[2])})
}
return
}
@@ -329,7 +329,7 @@ func (*Resolver) lookupAddr(ctx context.Context, addr string) (name []string, er
if len(f) < 3 {
continue
}
- name = append(name, absDomainName([]byte(f[2])))
+ name = append(name, absDomainName(f[2]))
}
return
}
diff --git a/libgo/go/net/lookup_test.go b/libgo/go/net/lookup_test.go
index 3faaf00..063d650 100644
--- a/libgo/go/net/lookup_test.go
+++ b/libgo/go/net/lookup_test.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build !js
-// +build !js
package net
@@ -353,6 +352,7 @@ var lookupCNAMETests = []struct {
func TestLookupCNAME(t *testing.T) {
mustHaveExternalNetwork(t)
+ testenv.SkipFlakyNet(t)
if !supportsIPv4() || !*testIPv4 {
t.Skip("IPv4 is required")
@@ -391,6 +391,7 @@ var lookupGoogleHostTests = []struct {
func TestLookupGoogleHost(t *testing.T) {
mustHaveExternalNetwork(t)
+ testenv.SkipFlakyNet(t)
if !supportsIPv4() || !*testIPv4 {
t.Skip("IPv4 is required")
@@ -443,6 +444,7 @@ var lookupGoogleIPTests = []struct {
func TestLookupGoogleIP(t *testing.T) {
mustHaveExternalNetwork(t)
+ testenv.SkipFlakyNet(t)
if !supportsIPv4() || !*testIPv4 {
t.Skip("IPv4 is required")
@@ -633,6 +635,7 @@ func TestLookupDotsWithRemoteSource(t *testing.T) {
testenv.SkipFlaky(t, 27992)
}
mustHaveExternalNetwork(t)
+ testenv.SkipFlakyNet(t)
if !supportsIPv4() || !*testIPv4 {
t.Skip("IPv4 is required")
@@ -657,7 +660,6 @@ func TestLookupDotsWithRemoteSource(t *testing.T) {
func testDots(t *testing.T, mode string) {
names, err := LookupAddr("8.8.8.8") // Google dns server
if err != nil {
- testenv.SkipFlakyNet(t)
t.Errorf("LookupAddr(8.8.8.8): %v (mode=%v)", err, mode)
} else {
for _, name := range names {
@@ -670,7 +672,6 @@ func testDots(t *testing.T, mode string) {
cname, err := LookupCNAME("www.mit.edu")
if err != nil {
- testenv.SkipFlakyNet(t)
t.Errorf("LookupCNAME(www.mit.edu, mode=%v): %v", mode, err)
} else if !strings.HasSuffix(cname, ".") {
t.Errorf("LookupCNAME(www.mit.edu) = %v, want cname ending in . with trailing dot (mode=%v)", cname, mode)
@@ -678,7 +679,6 @@ func testDots(t *testing.T, mode string) {
mxs, err := LookupMX("google.com")
if err != nil {
- testenv.SkipFlakyNet(t)
t.Errorf("LookupMX(google.com): %v (mode=%v)", err, mode)
} else {
for _, mx := range mxs {
@@ -691,7 +691,6 @@ func testDots(t *testing.T, mode string) {
nss, err := LookupNS("google.com")
if err != nil {
- testenv.SkipFlakyNet(t)
t.Errorf("LookupNS(google.com): %v (mode=%v)", err, mode)
} else {
for _, ns := range nss {
@@ -704,7 +703,6 @@ func testDots(t *testing.T, mode string) {
cname, srvs, err := LookupSRV("xmpp-server", "tcp", "google.com")
if err != nil {
- testenv.SkipFlakyNet(t)
t.Errorf("LookupSRV(xmpp-server, tcp, google.com): %v (mode=%v)", err, mode)
} else {
if !hasSuffixFold(cname, ".google.com.") {
@@ -890,7 +888,7 @@ func TestLookupContextCancel(t *testing.T) {
ctx, ctxCancel := context.WithCancel(context.Background())
ctxCancel()
_, err := DefaultResolver.LookupIPAddr(ctx, "google.com")
- if err != errCanceled {
+ if err.(*DNSError).Err != errCanceled.Error() {
testenv.SkipFlakyNet(t)
t.Fatal(err)
}
@@ -926,6 +924,9 @@ func TestNilResolverLookup(t *testing.T) {
// canceled lookups (see golang.org/issue/24178 for details).
func TestLookupHostCancel(t *testing.T) {
mustHaveExternalNetwork(t)
+ testenv.SkipFlakyNet(t)
+ t.Parallel() // Executes 600ms worth of sequential sleeps.
+
const (
google = "www.google.com"
invalidDomain = "invalid.invalid" // RFC 2606 reserves .invalid
@@ -944,9 +945,15 @@ func TestLookupHostCancel(t *testing.T) {
if err == nil {
t.Fatalf("LookupHost(%q): returns %v, but should fail", invalidDomain, addr)
}
- if !strings.Contains(err.Error(), "canceled") {
- t.Fatalf("LookupHost(%q): failed with unexpected error: %v", invalidDomain, err)
- }
+
+ // Don't verify what the actual error is.
+ // We know that it must be non-nil because the domain is invalid,
+ // but we don't have any guarantee that LookupHost actually bothers
+ // to check for cancellation on the fast path.
+ // (For example, it could use a local cache to avoid blocking entirely.)
+
+ // The lookup may deduplicate in-flight requests, so give it time to settle
+ // in between.
time.Sleep(time.Millisecond * 1)
}
@@ -1050,7 +1057,7 @@ func TestLookupIPAddrPreservesContextValues(t *testing.T) {
defer func() { testHookLookupIP = origTestHookLookupIP }()
keyValues := []struct {
- key, value interface{}
+ key, value any
}{
{"key-1", 12},
{384, "value2"},
@@ -1267,3 +1274,71 @@ func TestResolverLookupIP(t *testing.T) {
})
}
}
+
+// A context timeout should still return a DNSError.
+func TestDNSTimeout(t *testing.T) {
+ origTestHookLookupIP := testHookLookupIP
+ defer func() { testHookLookupIP = origTestHookLookupIP }()
+ defer dnsWaitGroup.Wait()
+
+ timeoutHookGo := make(chan bool, 1)
+ timeoutHook := func(ctx context.Context, fn func(context.Context, string, string) ([]IPAddr, error), network, host string) ([]IPAddr, error) {
+ <-timeoutHookGo
+ return nil, context.DeadlineExceeded
+ }
+ testHookLookupIP = timeoutHook
+
+ checkErr := func(err error) {
+ t.Helper()
+ if err == nil {
+ t.Error("expected an error")
+ } else if dnserr, ok := err.(*DNSError); !ok {
+ t.Errorf("got error type %T, want %T", err, (*DNSError)(nil))
+ } else if !dnserr.IsTimeout {
+ t.Errorf("got error %#v, want IsTimeout == true", dnserr)
+ } else if isTimeout := dnserr.Timeout(); !isTimeout {
+ t.Errorf("got err.Timeout() == %t, want true", isTimeout)
+ }
+ }
+
+ // Single lookup.
+ timeoutHookGo <- true
+ _, err := LookupIP("golang.org")
+ checkErr(err)
+
+ // Double lookup.
+ var err1, err2 error
+ var wg sync.WaitGroup
+ wg.Add(2)
+ go func() {
+ defer wg.Done()
+ _, err1 = LookupIP("golang1.org")
+ }()
+ go func() {
+ defer wg.Done()
+ _, err2 = LookupIP("golang1.org")
+ }()
+ close(timeoutHookGo)
+ wg.Wait()
+ checkErr(err1)
+ checkErr(err2)
+
+ // Double lookup with context.
+ timeoutHookGo = make(chan bool)
+ ctx, cancel := context.WithTimeout(context.Background(), time.Nanosecond)
+ wg.Add(2)
+ go func() {
+ defer wg.Done()
+ _, err1 = DefaultResolver.LookupIPAddr(ctx, "golang2.org")
+ }()
+ go func() {
+ defer wg.Done()
+ _, err2 = DefaultResolver.LookupIPAddr(ctx, "golang2.org")
+ }()
+ time.Sleep(10 * time.Nanosecond)
+ close(timeoutHookGo)
+ wg.Wait()
+ checkErr(err1)
+ checkErr(err2)
+ cancel()
+}
diff --git a/libgo/go/net/lookup_unix.go b/libgo/go/net/lookup_unix.go
index 05f49b0..0d25f22 100644
--- a/libgo/go/net/lookup_unix.go
+++ b/libgo/go/net/lookup_unix.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build aix || darwin || dragonfly || freebsd || hurd || linux || netbsd || openbsd || solaris
-// +build aix darwin dragonfly freebsd hurd linux netbsd openbsd solaris
package net
diff --git a/libgo/go/net/lookup_windows.go b/libgo/go/net/lookup_windows.go
index bb34a08..27e5f86 100644
--- a/libgo/go/net/lookup_windows.go
+++ b/libgo/go/net/lookup_windows.go
@@ -226,7 +226,7 @@ func (*Resolver) lookupCNAME(ctx context.Context, name string) (string, error) {
// windows returns DNS_INFO_NO_RECORDS if there are no CNAME-s
if errno, ok := e.(syscall.Errno); ok && errno == syscall.DNS_INFO_NO_RECORDS {
// if there are no aliases, the canonical name is the input name
- return absDomainName([]byte(name)), nil
+ return absDomainName(name), nil
}
if e != nil {
return "", &DNSError{Err: winError("dnsquery", e).Error(), Name: name}
@@ -235,7 +235,7 @@ func (*Resolver) lookupCNAME(ctx context.Context, name string) (string, error) {
resolved := resolveCNAME(syscall.StringToUTF16Ptr(name), r)
cname := windows.UTF16PtrToString(resolved)
- return absDomainName([]byte(cname)), nil
+ return absDomainName(cname), nil
}
func (*Resolver) lookupSRV(ctx context.Context, service, proto, name string) (string, []*SRV, error) {
@@ -258,10 +258,10 @@ func (*Resolver) lookupSRV(ctx context.Context, service, proto, name string) (st
srvs := make([]*SRV, 0, 10)
for _, p := range validRecs(r, syscall.DNS_TYPE_SRV, target) {
v := (*syscall.DNSSRVData)(unsafe.Pointer(&p.Data[0]))
- srvs = append(srvs, &SRV{absDomainName([]byte(syscall.UTF16ToString((*[256]uint16)(unsafe.Pointer(v.Target))[:]))), v.Port, v.Priority, v.Weight})
+ srvs = append(srvs, &SRV{absDomainName(syscall.UTF16ToString((*[256]uint16)(unsafe.Pointer(v.Target))[:])), v.Port, v.Priority, v.Weight})
}
byPriorityWeight(srvs).sort()
- return absDomainName([]byte(target)), srvs, nil
+ return absDomainName(target), srvs, nil
}
func (*Resolver) lookupMX(ctx context.Context, name string) ([]*MX, error) {
@@ -278,7 +278,7 @@ func (*Resolver) lookupMX(ctx context.Context, name string) ([]*MX, error) {
mxs := make([]*MX, 0, 10)
for _, p := range validRecs(r, syscall.DNS_TYPE_MX, name) {
v := (*syscall.DNSMXData)(unsafe.Pointer(&p.Data[0]))
- mxs = append(mxs, &MX{absDomainName([]byte(windows.UTF16PtrToString(v.NameExchange))), v.Preference})
+ mxs = append(mxs, &MX{absDomainName(windows.UTF16PtrToString(v.NameExchange)), v.Preference})
}
byPref(mxs).sort()
return mxs, nil
@@ -298,7 +298,7 @@ func (*Resolver) lookupNS(ctx context.Context, name string) ([]*NS, error) {
nss := make([]*NS, 0, 10)
for _, p := range validRecs(r, syscall.DNS_TYPE_NS, name) {
v := (*syscall.DNSPTRData)(unsafe.Pointer(&p.Data[0]))
- nss = append(nss, &NS{absDomainName([]byte(syscall.UTF16ToString((*[256]uint16)(unsafe.Pointer(v.Host))[:])))})
+ nss = append(nss, &NS{absDomainName(syscall.UTF16ToString((*[256]uint16)(unsafe.Pointer(v.Host))[:]))})
}
return nss, nil
}
@@ -344,7 +344,7 @@ func (*Resolver) lookupAddr(ctx context.Context, addr string) ([]string, error)
ptrs := make([]string, 0, 10)
for _, p := range validRecs(r, syscall.DNS_TYPE_PTR, arpa) {
v := (*syscall.DNSPTRData)(unsafe.Pointer(&p.Data[0]))
- ptrs = append(ptrs, absDomainName([]byte(windows.UTF16PtrToString(v.Host))))
+ ptrs = append(ptrs, absDomainName(windows.UTF16PtrToString(v.Host)))
}
return ptrs, nil
}
diff --git a/libgo/go/net/lookup_windows_test.go b/libgo/go/net/lookup_windows_test.go
index aa95501..9254733 100644
--- a/libgo/go/net/lookup_windows_test.go
+++ b/libgo/go/net/lookup_windows_test.go
@@ -21,7 +21,7 @@ import (
var nslookupTestServers = []string{"mail.golang.com", "gmail.com"}
var lookupTestIPs = []string{"8.8.8.8", "1.1.1.1"}
-func toJson(v interface{}) string {
+func toJson(v any) string {
data, _ := json.Marshal(v)
return string(data)
}
@@ -220,14 +220,14 @@ func nslookupMX(name string) (mx []*MX, err error) {
rx := regexp.MustCompile(`(?m)^([a-z0-9.\-]+)\s+mail exchanger\s*=\s*([0-9]+)\s*([a-z0-9.\-]+)$`)
for _, ans := range rx.FindAllStringSubmatch(r, -1) {
pref, _, _ := dtoi(ans[2])
- mx = append(mx, &MX{absDomainName([]byte(ans[3])), uint16(pref)})
+ mx = append(mx, &MX{absDomainName(ans[3]), uint16(pref)})
}
// windows nslookup syntax
// gmail.com MX preference = 30, mail exchanger = alt3.gmail-smtp-in.l.google.com
rx = regexp.MustCompile(`(?m)^([a-z0-9.\-]+)\s+MX preference\s*=\s*([0-9]+)\s*,\s*mail exchanger\s*=\s*([a-z0-9.\-]+)$`)
for _, ans := range rx.FindAllStringSubmatch(r, -1) {
pref, _, _ := dtoi(ans[2])
- mx = append(mx, &MX{absDomainName([]byte(ans[3])), uint16(pref)})
+ mx = append(mx, &MX{absDomainName(ans[3]), uint16(pref)})
}
return
}
@@ -241,7 +241,7 @@ func nslookupNS(name string) (ns []*NS, err error) {
// golang.org nameserver = ns1.google.com.
rx := regexp.MustCompile(`(?m)^([a-z0-9.\-]+)\s+nameserver\s*=\s*([a-z0-9.\-]+)$`)
for _, ans := range rx.FindAllStringSubmatch(r, -1) {
- ns = append(ns, &NS{absDomainName([]byte(ans[2]))})
+ ns = append(ns, &NS{absDomainName(ans[2])})
}
return
}
@@ -258,7 +258,7 @@ func nslookupCNAME(name string) (cname string, err error) {
for _, ans := range rx.FindAllStringSubmatch(r, -1) {
last = ans[2]
}
- return absDomainName([]byte(last)), nil
+ return absDomainName(last), nil
}
func nslookupTXT(name string) (txt []string, err error) {
@@ -299,7 +299,7 @@ func lookupPTR(name string) (ptr []string, err error) {
ptr = make([]string, 0, 10)
rx := regexp.MustCompile(`(?m)^Pinging\s+([a-zA-Z0-9.\-]+)\s+\[.*$`)
for _, ans := range rx.FindAllStringSubmatch(r, -1) {
- ptr = append(ptr, absDomainName([]byte(ans[1])))
+ ptr = append(ptr, absDomainName(ans[1]))
}
return
}
diff --git a/libgo/go/net/mail/message.go b/libgo/go/net/mail/message.go
index 47bbf6c..985b6fc 100644
--- a/libgo/go/net/mail/message.go
+++ b/libgo/go/net/mail/message.go
@@ -35,7 +35,7 @@ var debug = debugT(false)
type debugT bool
-func (d debugT) Printf(format string, args ...interface{}) {
+func (d debugT) Printf(format string, args ...any) {
if d {
log.Printf(format, args...)
}
@@ -100,7 +100,7 @@ func ParseDate(date string) (time.Time, error) {
dateLayoutsBuildOnce.Do(buildDateLayouts)
// CR and LF must match and are tolerated anywhere in the date field.
date = strings.ReplaceAll(date, "\r\n", "")
- if strings.Index(date, "\r") != -1 {
+ if strings.Contains(date, "\r") {
return time.Time{}, errors.New("mail: header has a CR without LF")
}
// Re-using some addrParser methods which support obsolete text, i.e. non-printable ASCII
diff --git a/libgo/go/net/main_cloexec_test.go b/libgo/go/net/main_cloexec_test.go
index 03f7d63..06f0671 100644
--- a/libgo/go/net/main_cloexec_test.go
+++ b/libgo/go/net/main_cloexec_test.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build dragonfly || freebsd || hurd || illumos || linux || netbsd || openbsd
-// +build dragonfly freebsd hurd illumos linux netbsd openbsd
package net
diff --git a/libgo/go/net/main_conf_test.go b/libgo/go/net/main_conf_test.go
index 645b267..41b78ed 100644
--- a/libgo/go/net/main_conf_test.go
+++ b/libgo/go/net/main_conf_test.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build !js && !plan9 && !windows
-// +build !js,!plan9,!windows
package net
diff --git a/libgo/go/net/main_noconf_test.go b/libgo/go/net/main_noconf_test.go
index bcea630..ab050fa 100644
--- a/libgo/go/net/main_noconf_test.go
+++ b/libgo/go/net/main_noconf_test.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build (js && wasm) || plan9 || windows
-// +build js,wasm plan9 windows
package net
diff --git a/libgo/go/net/main_posix_test.go b/libgo/go/net/main_posix_test.go
index c9ab25a..8899aa9 100644
--- a/libgo/go/net/main_posix_test.go
+++ b/libgo/go/net/main_posix_test.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build !js && !plan9
-// +build !js,!plan9
package net
@@ -18,9 +17,9 @@ func enableSocketConnect() {
}
func disableSocketConnect(network string) {
- ss := strings.Split(network, ":")
+ net, _, _ := strings.Cut(network, ":")
sw.Set(socktest.FilterConnect, func(so *socktest.Status) (socktest.AfterFilter, error) {
- switch ss[0] {
+ switch net {
case "tcp4":
if so.Cookie.Family() == syscall.AF_INET && so.Cookie.Type() == syscall.SOCK_STREAM {
return nil, syscall.EHOSTUNREACH
diff --git a/libgo/go/net/main_test.go b/libgo/go/net/main_test.go
index dc17d3f..1ee8c2e 100644
--- a/libgo/go/net/main_test.go
+++ b/libgo/go/net/main_test.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build !js
-// +build !js
package net
@@ -174,11 +173,8 @@ func runningGoroutines() []string {
b := make([]byte, 2<<20)
b = b[:runtime.Stack(b, true)]
for _, s := range strings.Split(string(b), "\n\n") {
- ss := strings.SplitN(s, "\n", 2)
- if len(ss) != 2 {
- continue
- }
- stack := strings.TrimSpace(ss[1])
+ _, stack, _ := strings.Cut(s, "\n")
+ stack = strings.TrimSpace(stack)
if !strings.Contains(stack, "created by net") {
continue
}
diff --git a/libgo/go/net/main_unix_test.go b/libgo/go/net/main_unix_test.go
index 367cefc..402da4d 100644
--- a/libgo/go/net/main_unix_test.go
+++ b/libgo/go/net/main_unix_test.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build aix || darwin || dragonfly || freebsd || hurd || linux || netbsd || openbsd || solaris
-// +build aix darwin dragonfly freebsd hurd linux netbsd openbsd solaris
package net
diff --git a/libgo/go/net/mockserver_test.go b/libgo/go/net/mockserver_test.go
index b50a1e5..186bd33 100644
--- a/libgo/go/net/mockserver_test.go
+++ b/libgo/go/net/mockserver_test.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build !js
-// +build !js
package net
@@ -11,46 +10,67 @@ import (
"errors"
"fmt"
"os"
+ "path/filepath"
"sync"
"testing"
"time"
)
-// testUnixAddr uses os.CreateTemp to get a name that is unique.
-func testUnixAddr() string {
- f, err := os.CreateTemp("", "go-nettest")
+// testUnixAddr uses os.MkdirTemp to get a name that is unique.
+func testUnixAddr(t testing.TB) string {
+ // Pass an empty pattern to get a directory name that is as short as possible.
+ // If we end up with a name longer than the sun_path field in the sockaddr_un
+ // struct, we won't be able to make the syscall to open the socket.
+ d, err := os.MkdirTemp("", "")
if err != nil {
- panic(err)
+ t.Fatal(err)
}
- addr := f.Name()
- f.Close()
- os.Remove(addr)
- return addr
+ t.Cleanup(func() {
+ if err := os.RemoveAll(d); err != nil {
+ t.Error(err)
+ }
+ })
+ return filepath.Join(d, "sock")
}
-func newLocalListener(network string) (Listener, error) {
+func newLocalListener(t testing.TB, network string) Listener {
+ listen := func(net, addr string) Listener {
+ ln, err := Listen(net, addr)
+ if err != nil {
+ t.Helper()
+ t.Fatal(err)
+ }
+ return ln
+ }
+
switch network {
case "tcp":
if supportsIPv4() {
+ if !supportsIPv6() {
+ return listen("tcp4", "127.0.0.1:0")
+ }
if ln, err := Listen("tcp4", "127.0.0.1:0"); err == nil {
- return ln, nil
+ return ln
}
}
if supportsIPv6() {
- return Listen("tcp6", "[::1]:0")
+ return listen("tcp6", "[::1]:0")
}
case "tcp4":
if supportsIPv4() {
- return Listen("tcp4", "127.0.0.1:0")
+ return listen("tcp4", "127.0.0.1:0")
}
case "tcp6":
if supportsIPv6() {
- return Listen("tcp6", "[::1]:0")
+ return listen("tcp6", "[::1]:0")
}
case "unix", "unixpacket":
- return Listen(network, testUnixAddr())
+ return listen(network, testUnixAddr(t))
}
- return nil, fmt.Errorf("%s is not supported", network)
+
+ t.Helper()
+ t.Fatalf("%s is not supported", network)
+ return nil
}
func newDualStackListener() (lns []*TCPListener, err error) {
@@ -121,12 +141,10 @@ func (ls *localServer) teardown() error {
return nil
}
-func newLocalServer(network string) (*localServer, error) {
- ln, err := newLocalListener(network)
- if err != nil {
- return nil, err
- }
- return &localServer{Listener: ln, done: make(chan bool)}, nil
+func newLocalServer(t testing.TB, network string) *localServer {
+ t.Helper()
+ ln := newLocalListener(t, network)
+ return &localServer{Listener: ln, done: make(chan bool)}
}
type streamListener struct {
@@ -135,8 +153,8 @@ type streamListener struct {
done chan bool // signal that indicates server stopped
}
-func (sl *streamListener) newLocalServer() (*localServer, error) {
- return &localServer{Listener: sl.Listener, done: make(chan bool)}, nil
+func (sl *streamListener) newLocalServer() *localServer {
+ return &localServer{Listener: sl.Listener, done: make(chan bool)}
}
type dualStackServer struct {
@@ -288,75 +306,39 @@ func transceiver(c Conn, wb []byte, ch chan<- error) {
}
}
-func timeoutReceiver(c Conn, d, min, max time.Duration, ch chan<- error) {
- var err error
- defer func() { ch <- err }()
-
- t0 := time.Now()
- if err = c.SetReadDeadline(time.Now().Add(d)); err != nil {
- return
- }
- b := make([]byte, 256)
- var n int
- n, err = c.Read(b)
- t1 := time.Now()
- if n != 0 || err == nil || !err.(Error).Timeout() {
- err = fmt.Errorf("Read did not return (0, timeout): (%d, %v)", n, err)
- return
- }
- if dt := t1.Sub(t0); min > dt || dt > max && !testing.Short() {
- err = fmt.Errorf("Read took %s; expected %s", dt, d)
- return
- }
-}
-
-func timeoutTransmitter(c Conn, d, min, max time.Duration, ch chan<- error) {
- var err error
- defer func() { ch <- err }()
-
- t0 := time.Now()
- if err = c.SetWriteDeadline(time.Now().Add(d)); err != nil {
- return
- }
- var n int
- for {
- n, err = c.Write([]byte("TIMEOUT TRANSMITTER"))
+func newLocalPacketListener(t testing.TB, network string) PacketConn {
+ listenPacket := func(net, addr string) PacketConn {
+ c, err := ListenPacket(net, addr)
if err != nil {
- break
+ t.Helper()
+ t.Fatal(err)
}
+ return c
}
- t1 := time.Now()
- if err == nil || !err.(Error).Timeout() {
- err = fmt.Errorf("Write did not return (any, timeout): (%d, %v)", n, err)
- return
- }
- if dt := t1.Sub(t0); min > dt || dt > max && !testing.Short() {
- err = fmt.Errorf("Write took %s; expected %s", dt, d)
- return
- }
-}
-func newLocalPacketListener(network string) (PacketConn, error) {
switch network {
case "udp":
if supportsIPv4() {
- return ListenPacket("udp4", "127.0.0.1:0")
+ return listenPacket("udp4", "127.0.0.1:0")
}
if supportsIPv6() {
- return ListenPacket("udp6", "[::1]:0")
+ return listenPacket("udp6", "[::1]:0")
}
case "udp4":
if supportsIPv4() {
- return ListenPacket("udp4", "127.0.0.1:0")
+ return listenPacket("udp4", "127.0.0.1:0")
}
case "udp6":
if supportsIPv6() {
- return ListenPacket("udp6", "[::1]:0")
+ return listenPacket("udp6", "[::1]:0")
}
case "unixgram":
- return ListenPacket(network, testUnixAddr())
+ return listenPacket(network, testUnixAddr(t))
}
- return nil, fmt.Errorf("%s is not supported", network)
+
+ t.Helper()
+ t.Fatalf("%s is not supported", network)
+ return nil
}
func newDualStackPacketListener() (cs []*UDPConn, err error) {
@@ -421,20 +403,18 @@ func (ls *localPacketServer) teardown() error {
return nil
}
-func newLocalPacketServer(network string) (*localPacketServer, error) {
- c, err := newLocalPacketListener(network)
- if err != nil {
- return nil, err
- }
- return &localPacketServer{PacketConn: c, done: make(chan bool)}, nil
+func newLocalPacketServer(t testing.TB, network string) *localPacketServer {
+ t.Helper()
+ c := newLocalPacketListener(t, network)
+ return &localPacketServer{PacketConn: c, done: make(chan bool)}
}
type packetListener struct {
PacketConn
}
-func (pl *packetListener) newLocalServer() (*localPacketServer, error) {
- return &localPacketServer{PacketConn: pl.PacketConn, done: make(chan bool)}, nil
+func (pl *packetListener) newLocalServer() *localPacketServer {
+ return &localPacketServer{PacketConn: pl.PacketConn, done: make(chan bool)}
}
func packetTransponder(c PacketConn, ch chan<- error) {
@@ -505,25 +485,3 @@ func packetTransceiver(c PacketConn, wb []byte, dst Addr, ch chan<- error) {
ch <- fmt.Errorf("read %d; want %d", n, len(wb))
}
}
-
-func timeoutPacketReceiver(c PacketConn, d, min, max time.Duration, ch chan<- error) {
- var err error
- defer func() { ch <- err }()
-
- t0 := time.Now()
- if err = c.SetReadDeadline(time.Now().Add(d)); err != nil {
- return
- }
- b := make([]byte, 256)
- var n int
- n, _, err = c.ReadFrom(b)
- t1 := time.Now()
- if n != 0 || err == nil || !err.(Error).Timeout() {
- err = fmt.Errorf("ReadFrom did not return (0, timeout): (%d, %v)", n, err)
- return
- }
- if dt := t1.Sub(t0); min > dt || dt > max && !testing.Short() {
- err = fmt.Errorf("ReadFrom took %s; expected %s", dt, d)
- return
- }
-}
diff --git a/libgo/go/net/net.go b/libgo/go/net/net.go
index a7c65ff..77e54a9 100644
--- a/libgo/go/net/net.go
+++ b/libgo/go/net/net.go
@@ -125,10 +125,10 @@ type Conn interface {
// Any blocked Read or Write operations will be unblocked and return errors.
Close() error
- // LocalAddr returns the local network address.
+ // LocalAddr returns the local network address, if known.
LocalAddr() Addr
- // RemoteAddr returns the remote network address.
+ // RemoteAddr returns the remote network address, if known.
RemoteAddr() Addr
// SetDeadline sets the read and write deadlines associated
@@ -328,7 +328,7 @@ type PacketConn interface {
// Any blocked ReadFrom or WriteTo operations will be unblocked and return errors.
Close() error
- // LocalAddr returns the local network address.
+ // LocalAddr returns the local network address, if known.
LocalAddr() Addr
// SetDeadline sets the read and write deadlines associated
@@ -396,8 +396,12 @@ type Listener interface {
// An Error represents a network error.
type Error interface {
error
- Timeout() bool // Is the error a timeout?
- Temporary() bool // Is the error temporary?
+ Timeout() bool // Is the error a timeout?
+
+ // Deprecated: Temporary errors are not well-defined.
+ // Most "temporary" errors are timeouts, and the few exceptions are surprising.
+ // Do not use this method.
+ Temporary() bool
}
// Various errors contained in OpError.
diff --git a/libgo/go/net/net_fake.go b/libgo/go/net/net_fake.go
index 74fc1da..ee5644c 100644
--- a/libgo/go/net/net_fake.go
+++ b/libgo/go/net/net_fake.go
@@ -5,7 +5,6 @@
// Fake networking for js/wasm. It is intended to allow tests of other package to pass.
//go:build js && wasm
-// +build js,wasm
package net
@@ -266,16 +265,48 @@ func sysSocket(family, sotype, proto int) (int, error) {
func (fd *netFD) readFrom(p []byte) (n int, sa syscall.Sockaddr, err error) {
return 0, nil, syscall.ENOSYS
+
+}
+func (fd *netFD) readFromInet4(p []byte, sa *syscall.SockaddrInet4) (n int, err error) {
+ return 0, syscall.ENOSYS
+}
+
+func (fd *netFD) readFromInet6(p []byte, sa *syscall.SockaddrInet6) (n int, err error) {
+ return 0, syscall.ENOSYS
}
func (fd *netFD) readMsg(p []byte, oob []byte, flags int) (n, oobn, retflags int, sa syscall.Sockaddr, err error) {
return 0, 0, 0, nil, syscall.ENOSYS
}
+func (fd *netFD) readMsgInet4(p []byte, oob []byte, flags int, sa *syscall.SockaddrInet4) (n, oobn, retflags int, err error) {
+ return 0, 0, 0, syscall.ENOSYS
+}
+
+func (fd *netFD) readMsgInet6(p []byte, oob []byte, flags int, sa *syscall.SockaddrInet6) (n, oobn, retflags int, err error) {
+ return 0, 0, 0, syscall.ENOSYS
+}
+
+func (fd *netFD) writeMsgInet4(p []byte, oob []byte, sa *syscall.SockaddrInet4) (n int, oobn int, err error) {
+ return 0, 0, syscall.ENOSYS
+}
+
+func (fd *netFD) writeMsgInet6(p []byte, oob []byte, sa *syscall.SockaddrInet6) (n int, oobn int, err error) {
+ return 0, 0, syscall.ENOSYS
+}
+
func (fd *netFD) writeTo(p []byte, sa syscall.Sockaddr) (n int, err error) {
return 0, syscall.ENOSYS
}
+func (fd *netFD) writeToInet4(p []byte, sa *syscall.SockaddrInet4) (n int, err error) {
+ return 0, syscall.ENOSYS
+}
+
+func (fd *netFD) writeToInet6(p []byte, sa *syscall.SockaddrInet6) (n int, err error) {
+ return 0, syscall.ENOSYS
+}
+
func (fd *netFD) writeMsg(p []byte, oob []byte, sa syscall.Sockaddr) (n int, oobn int, err error) {
return 0, 0, syscall.ENOSYS
}
diff --git a/libgo/go/net/net_test.go b/libgo/go/net/net_test.go
index 6e7be4d..7b16991 100644
--- a/libgo/go/net/net_test.go
+++ b/libgo/go/net/net_test.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build !js
-// +build !js
package net
@@ -34,10 +33,7 @@ func TestCloseRead(t *testing.T) {
}
t.Parallel()
- ln, err := newLocalListener(network)
- if err != nil {
- t.Fatal(err)
- }
+ ln := newLocalListener(t, network)
switch network {
case "unix", "unixpacket":
defer os.Remove(ln.Addr().String())
@@ -133,10 +129,7 @@ func TestCloseWrite(t *testing.T) {
}
}
- ls, err := newLocalServer(network)
- if err != nil {
- t.Fatal(err)
- }
+ ls := newLocalServer(t, network)
defer ls.teardown()
if err := ls.buildup(handler); err != nil {
t.Fatal(err)
@@ -190,10 +183,7 @@ func TestConnClose(t *testing.T) {
}
t.Parallel()
- ln, err := newLocalListener(network)
- if err != nil {
- t.Fatal(err)
- }
+ ln := newLocalListener(t, network)
switch network {
case "unix", "unixpacket":
defer os.Remove(ln.Addr().String())
@@ -235,16 +225,12 @@ func TestListenerClose(t *testing.T) {
}
t.Parallel()
- ln, err := newLocalListener(network)
- if err != nil {
- t.Fatal(err)
- }
+ ln := newLocalListener(t, network)
switch network {
case "unix", "unixpacket":
defer os.Remove(ln.Addr().String())
}
- dst := ln.Addr().String()
if err := ln.Close(); err != nil {
if perr := parseCloseError(err, false); perr != nil {
t.Error(perr)
@@ -257,28 +243,12 @@ func TestListenerClose(t *testing.T) {
t.Fatal("should fail")
}
- if network == "tcp" {
- // We will have two TCP FSMs inside the
- // kernel here. There's no guarantee that a
- // signal comes from the far end FSM will be
- // delivered immediately to the near end FSM,
- // especially on the platforms that allow
- // multiple consumer threads to pull pending
- // established connections at the same time by
- // enabling SO_REUSEPORT option such as Linux,
- // DragonFly BSD. So we need to give some time
- // quantum to the kernel.
- //
- // Note that net.inet.tcp.reuseport_ext=1 by
- // default on DragonFly BSD.
- time.Sleep(time.Millisecond)
-
- cc, err := Dial("tcp", dst)
- if err == nil {
- t.Error("Dial to closed TCP listener succeeded.")
- cc.Close()
- }
- }
+ // Note: we cannot ensure that a subsequent Dial does not succeed, because
+ // we do not in general have any guarantee that ln.Addr is not immediately
+ // reused. (TCP sockets enter a TIME_WAIT state when closed, but that only
+ // applies to existing connections for the port โ€” it does not prevent the
+ // port itself from being used for entirely new connections in the
+ // meantime.)
})
}
}
@@ -293,10 +263,7 @@ func TestPacketConnClose(t *testing.T) {
}
t.Parallel()
- c, err := newLocalPacketListener(network)
- if err != nil {
- t.Fatal(err)
- }
+ c := newLocalPacketListener(t, network)
switch network {
case "unixgram":
defer os.Remove(c.LocalAddr().String())
@@ -321,18 +288,17 @@ func TestPacketConnClose(t *testing.T) {
func TestListenCloseListen(t *testing.T) {
const maxTries = 10
for tries := 0; tries < maxTries; tries++ {
- ln, err := newLocalListener("tcp")
- if err != nil {
- t.Fatal(err)
- }
+ ln := newLocalListener(t, "tcp")
addr := ln.Addr().String()
+ // TODO: This is racy. The selected address could be reused in between this
+ // Close and the subsequent Listen.
if err := ln.Close(); err != nil {
if perr := parseCloseError(err, false); perr != nil {
t.Error(perr)
}
t.Fatal(err)
}
- ln, err = Listen("tcp", addr)
+ ln, err := Listen("tcp", addr)
if err == nil {
// Success. (This test didn't always make it here earlier.)
ln.Close()
@@ -378,10 +344,7 @@ func TestAcceptIgnoreAbortedConnRequest(t *testing.T) {
}
c.Close()
}
- ls, err := newLocalServer("tcp")
- if err != nil {
- t.Fatal(err)
- }
+ ls := newLocalServer(t, "tcp")
defer ls.teardown()
if err := ls.buildup(handler); err != nil {
t.Fatal(err)
@@ -408,10 +371,7 @@ func TestZeroByteRead(t *testing.T) {
}
t.Parallel()
- ln, err := newLocalListener(network)
- if err != nil {
- t.Fatal(err)
- }
+ ln := newLocalListener(t, network)
connc := make(chan Conn, 1)
go func() {
defer ln.Close()
@@ -460,10 +420,7 @@ func TestZeroByteRead(t *testing.T) {
// runs peer1 and peer2 concurrently. withTCPConnPair returns when
// both have completed.
func withTCPConnPair(t *testing.T, peer1, peer2 func(c *TCPConn) error) {
- ln, err := newLocalListener("tcp")
- if err != nil {
- t.Fatal(err)
- }
+ ln := newLocalListener(t, "tcp")
defer ln.Close()
errc := make(chan error, 2)
go func() {
diff --git a/libgo/go/net/netip/export_test.go b/libgo/go/net/netip/export_test.go
new file mode 100644
index 0000000..59971fa
--- /dev/null
+++ b/libgo/go/net/netip/export_test.go
@@ -0,0 +1,30 @@
+// Copyright 2021 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package netip
+
+import "internal/intern"
+
+var (
+ Z0 = z0
+ Z4 = z4
+ Z6noz = z6noz
+)
+
+type Uint128 = uint128
+
+func Mk128(hi, lo uint64) Uint128 {
+ return uint128{hi, lo}
+}
+
+func MkAddr(u Uint128, z *intern.Value) Addr {
+ return Addr{u, z}
+}
+
+func IPv4(a, b, c, d uint8) Addr { return AddrFrom4([4]byte{a, b, c, d}) }
+
+var TestAppendToMarshal = testAppendToMarshal
+
+func (a Addr) IsZero() bool { return a.isZero() }
+func (p Prefix) IsZero() bool { return p.isZero() }
diff --git a/libgo/go/net/netip/fuzz_test.go b/libgo/go/net/netip/fuzz_test.go
new file mode 100644
index 0000000..4edbcf6
--- /dev/null
+++ b/libgo/go/net/netip/fuzz_test.go
@@ -0,0 +1,353 @@
+// Copyright 2021 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build ignore_due_to_generics
+
+package netip_test
+
+import (
+ "bytes"
+ "encoding"
+ "fmt"
+ "net"
+ . "net/netip"
+ "reflect"
+ "strings"
+ "testing"
+)
+
+var corpus = []string{
+ // Basic zero IPv4 address.
+ "0.0.0.0",
+ // Basic non-zero IPv4 address.
+ "192.168.140.255",
+ // IPv4 address in windows-style "print all the digits" form.
+ "010.000.015.001",
+ // IPv4 address with a silly amount of leading zeros.
+ "000001.00000002.00000003.000000004",
+ // 4-in-6 with octet with leading zero
+ "::ffff:1.2.03.4",
+ // Basic zero IPv6 address.
+ "::",
+ // Localhost IPv6.
+ "::1",
+ // Fully expanded IPv6 address.
+ "fd7a:115c:a1e0:ab12:4843:cd96:626b:430b",
+ // IPv6 with elided fields in the middle.
+ "fd7a:115c::626b:430b",
+ // IPv6 with elided fields at the end.
+ "fd7a:115c:a1e0:ab12:4843:cd96::",
+ // IPv6 with single elided field at the end.
+ "fd7a:115c:a1e0:ab12:4843:cd96:626b::",
+ "fd7a:115c:a1e0:ab12:4843:cd96:626b:0",
+ // IPv6 with single elided field in the middle.
+ "fd7a:115c:a1e0::4843:cd96:626b:430b",
+ "fd7a:115c:a1e0:0:4843:cd96:626b:430b",
+ // IPv6 with the trailing 32 bits written as IPv4 dotted decimal. (4in6)
+ "::ffff:192.168.140.255",
+ "::ffff:192.168.140.255",
+ // IPv6 with a zone specifier.
+ "fd7a:115c:a1e0:ab12:4843:cd96:626b:430b%eth0",
+ // IPv6 with dotted decimal and zone specifier.
+ "1:2::ffff:192.168.140.255%eth1",
+ "1:2::ffff:c0a8:8cff%eth1",
+ // IPv6 with capital letters.
+ "FD9E:1A04:F01D::1",
+ "fd9e:1a04:f01d::1",
+ // Empty string.
+ "",
+ // Garbage non-IP.
+ "bad",
+ // Single number. Some parsers accept this as an IPv4 address in
+ // big-endian uint32 form, but we don't.
+ "1234",
+ // IPv4 with a zone specifier.
+ "1.2.3.4%eth0",
+ // IPv4 field must have at least one digit.
+ ".1.2.3",
+ "1.2.3.",
+ "1..2.3",
+ // IPv4 address too long.
+ "1.2.3.4.5",
+ // IPv4 in dotted octal form.
+ "0300.0250.0214.0377",
+ // IPv4 in dotted hex form.
+ "0xc0.0xa8.0x8c.0xff",
+ // IPv4 in class B form.
+ "192.168.12345",
+ // IPv4 in class B form, with a small enough number to be
+ // parseable as a regular dotted decimal field.
+ "127.0.1",
+ // IPv4 in class A form.
+ "192.1234567",
+ // IPv4 in class A form, with a small enough number to be
+ // parseable as a regular dotted decimal field.
+ "127.1",
+ // IPv4 field has value >255.
+ "192.168.300.1",
+ // IPv4 with too many fields.
+ "192.168.0.1.5.6",
+ // IPv6 with not enough fields.
+ "1:2:3:4:5:6:7",
+ // IPv6 with too many fields.
+ "1:2:3:4:5:6:7:8:9",
+ // IPv6 with 8 fields and a :: expander.
+ "1:2:3:4::5:6:7:8",
+ // IPv6 with a field bigger than 2b.
+ "fe801::1",
+ // IPv6 with non-hex values in field.
+ "fe80:tail:scal:e::",
+ // IPv6 with a zone delimiter but no zone.
+ "fe80::1%",
+ // IPv6 with a zone specifier of zero.
+ "::ffff:0:0%0",
+ // IPv6 (without ellipsis) with too many fields for trailing embedded IPv4.
+ "ffff:ffff:ffff:ffff:ffff:ffff:ffff:192.168.140.255",
+ // IPv6 (with ellipsis) with too many fields for trailing embedded IPv4.
+ "ffff::ffff:ffff:ffff:ffff:ffff:ffff:192.168.140.255",
+ // IPv6 with invalid embedded IPv4.
+ "::ffff:192.168.140.bad",
+ // IPv6 with multiple ellipsis ::.
+ "fe80::1::1",
+ // IPv6 with invalid non hex/colon character.
+ "fe80:1?:1",
+ // IPv6 with truncated bytes after single colon.
+ "fe80:",
+ // AddrPort strings.
+ "1.2.3.4:51820",
+ "[fd7a:115c:a1e0:ab12:4843:cd96:626b:430b]:80",
+ "[::ffff:c000:0280]:65535",
+ "[::ffff:c000:0280%eth0]:1",
+ // Prefix strings.
+ "1.2.3.4/24",
+ "fd7a:115c:a1e0:ab12:4843:cd96:626b:430b/118",
+ "::ffff:c000:0280/96",
+ "::ffff:c000:0280%eth0/37",
+}
+
+func FuzzParse(f *testing.F) {
+ for _, seed := range corpus {
+ f.Add(seed)
+ }
+
+ f.Fuzz(func(t *testing.T, s string) {
+ ip, _ := ParseAddr(s)
+ checkStringParseRoundTrip(t, ip, ParseAddr)
+ checkEncoding(t, ip)
+
+ // Check that we match the net's IP parser, modulo zones.
+ if !strings.Contains(s, "%") {
+ stdip := net.ParseIP(s)
+ if !ip.IsValid() != (stdip == nil) {
+ t.Errorf("ParseAddr zero != net.ParseIP nil: ip=%q stdip=%q", ip, stdip)
+ }
+
+ if ip.IsValid() && !ip.Is4In6() {
+ buf, err := ip.MarshalText()
+ if err != nil {
+ t.Fatal(err)
+ }
+ buf2, err := stdip.MarshalText()
+ if err != nil {
+ t.Fatal(err)
+ }
+ if !bytes.Equal(buf, buf2) {
+ t.Errorf("Addr.MarshalText() != net.IP.MarshalText(): ip=%q stdip=%q", ip, stdip)
+ }
+ if ip.String() != stdip.String() {
+ t.Errorf("Addr.String() != net.IP.String(): ip=%q stdip=%q", ip, stdip)
+ }
+ if ip.IsGlobalUnicast() != stdip.IsGlobalUnicast() {
+ t.Errorf("Addr.IsGlobalUnicast() != net.IP.IsGlobalUnicast(): ip=%q stdip=%q", ip, stdip)
+ }
+ if ip.IsInterfaceLocalMulticast() != stdip.IsInterfaceLocalMulticast() {
+ t.Errorf("Addr.IsInterfaceLocalMulticast() != net.IP.IsInterfaceLocalMulticast(): ip=%q stdip=%q", ip, stdip)
+ }
+ if ip.IsLinkLocalMulticast() != stdip.IsLinkLocalMulticast() {
+ t.Errorf("Addr.IsLinkLocalMulticast() != net.IP.IsLinkLocalMulticast(): ip=%q stdip=%q", ip, stdip)
+ }
+ if ip.IsLinkLocalUnicast() != stdip.IsLinkLocalUnicast() {
+ t.Errorf("Addr.IsLinkLocalUnicast() != net.IP.IsLinkLocalUnicast(): ip=%q stdip=%q", ip, stdip)
+ }
+ if ip.IsLoopback() != stdip.IsLoopback() {
+ t.Errorf("Addr.IsLoopback() != net.IP.IsLoopback(): ip=%q stdip=%q", ip, stdip)
+ }
+ if ip.IsMulticast() != stdip.IsMulticast() {
+ t.Errorf("Addr.IsMulticast() != net.IP.IsMulticast(): ip=%q stdip=%q", ip, stdip)
+ }
+ if ip.IsPrivate() != stdip.IsPrivate() {
+ t.Errorf("Addr.IsPrivate() != net.IP.IsPrivate(): ip=%q stdip=%q", ip, stdip)
+ }
+ if ip.IsUnspecified() != stdip.IsUnspecified() {
+ t.Errorf("Addr.IsUnspecified() != net.IP.IsUnspecified(): ip=%q stdip=%q", ip, stdip)
+ }
+ }
+ }
+
+ // Check that .Next().Prev() and .Prev().Next() preserve the IP.
+ if ip.IsValid() && ip.Next().IsValid() && ip.Next().Prev() != ip {
+ t.Errorf(".Next.Prev did not round trip: ip=%q .next=%q .next.prev=%q", ip, ip.Next(), ip.Next().Prev())
+ }
+ if ip.IsValid() && ip.Prev().IsValid() && ip.Prev().Next() != ip {
+ t.Errorf(".Prev.Next did not round trip: ip=%q .prev=%q .prev.next=%q", ip, ip.Prev(), ip.Prev().Next())
+ }
+
+ port, err := ParseAddrPort(s)
+ if err == nil {
+ checkStringParseRoundTrip(t, port, ParseAddrPort)
+ checkEncoding(t, port)
+ }
+ port = AddrPortFrom(ip, 80)
+ checkStringParseRoundTrip(t, port, ParseAddrPort)
+ checkEncoding(t, port)
+
+ ipp, err := ParsePrefix(s)
+ if err == nil {
+ checkStringParseRoundTrip(t, ipp, ParsePrefix)
+ checkEncoding(t, ipp)
+ }
+ ipp = PrefixFrom(ip, 8)
+ checkStringParseRoundTrip(t, ipp, ParsePrefix)
+ checkEncoding(t, ipp)
+ })
+}
+
+// checkTextMarshaler checks that x's MarshalText and UnmarshalText functions round trip correctly.
+func checkTextMarshaler(t *testing.T, x encoding.TextMarshaler) {
+ buf, err := x.MarshalText()
+ if err != nil {
+ t.Fatal(err)
+ }
+ y := reflect.New(reflect.TypeOf(x)).Interface().(encoding.TextUnmarshaler)
+ err = y.UnmarshalText(buf)
+ if err != nil {
+ t.Logf("(%v).MarshalText() = %q", x, buf)
+ t.Fatalf("(%T).UnmarshalText(%q) = %v", y, buf, err)
+ }
+ e := reflect.ValueOf(y).Elem().Interface()
+ if !reflect.DeepEqual(x, e) {
+ t.Logf("(%v).MarshalText() = %q", x, buf)
+ t.Logf("(%T).UnmarshalText(%q) = %v", y, buf, y)
+ t.Fatalf("MarshalText/UnmarshalText failed to round trip: %#v != %#v", x, e)
+ }
+ buf2, err := y.(encoding.TextMarshaler).MarshalText()
+ if err != nil {
+ t.Logf("(%v).MarshalText() = %q", x, buf)
+ t.Logf("(%T).UnmarshalText(%q) = %v", y, buf, y)
+ t.Fatalf("failed to MarshalText a second time: %v", err)
+ }
+ if !bytes.Equal(buf, buf2) {
+ t.Logf("(%v).MarshalText() = %q", x, buf)
+ t.Logf("(%T).UnmarshalText(%q) = %v", y, buf, y)
+ t.Logf("(%v).MarshalText() = %q", y, buf2)
+ t.Fatalf("second MarshalText differs from first: %q != %q", buf, buf2)
+ }
+}
+
+// checkBinaryMarshaler checks that x's MarshalText and UnmarshalText functions round trip correctly.
+func checkBinaryMarshaler(t *testing.T, x encoding.BinaryMarshaler) {
+ buf, err := x.MarshalBinary()
+ if err != nil {
+ t.Fatal(err)
+ }
+ y := reflect.New(reflect.TypeOf(x)).Interface().(encoding.BinaryUnmarshaler)
+ err = y.UnmarshalBinary(buf)
+ if err != nil {
+ t.Logf("(%v).MarshalBinary() = %q", x, buf)
+ t.Fatalf("(%T).UnmarshalBinary(%q) = %v", y, buf, err)
+ }
+ e := reflect.ValueOf(y).Elem().Interface()
+ if !reflect.DeepEqual(x, e) {
+ t.Logf("(%v).MarshalBinary() = %q", x, buf)
+ t.Logf("(%T).UnmarshalBinary(%q) = %v", y, buf, y)
+ t.Fatalf("MarshalBinary/UnmarshalBinary failed to round trip: %#v != %#v", x, e)
+ }
+ buf2, err := y.(encoding.BinaryMarshaler).MarshalBinary()
+ if err != nil {
+ t.Logf("(%v).MarshalBinary() = %q", x, buf)
+ t.Logf("(%T).UnmarshalBinary(%q) = %v", y, buf, y)
+ t.Fatalf("failed to MarshalBinary a second time: %v", err)
+ }
+ if !bytes.Equal(buf, buf2) {
+ t.Logf("(%v).MarshalBinary() = %q", x, buf)
+ t.Logf("(%T).UnmarshalBinary(%q) = %v", y, buf, y)
+ t.Logf("(%v).MarshalBinary() = %q", y, buf2)
+ t.Fatalf("second MarshalBinary differs from first: %q != %q", buf, buf2)
+ }
+}
+
+func checkTextMarshalMatchesString(t *testing.T, x netipType) {
+ buf, err := x.MarshalText()
+ if err != nil {
+ t.Fatal(err)
+ }
+ str := x.String()
+ if string(buf) != str {
+ t.Fatalf("%v: MarshalText = %q, String = %q", x, buf, str)
+ }
+}
+
+type appendMarshaler interface {
+ encoding.TextMarshaler
+ AppendTo([]byte) []byte
+}
+
+// checkTextMarshalMatchesAppendTo checks that x's MarshalText matches x's AppendTo.
+func checkTextMarshalMatchesAppendTo(t *testing.T, x appendMarshaler) {
+ buf, err := x.MarshalText()
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ buf2 := make([]byte, 0, len(buf))
+ buf2 = x.AppendTo(buf2)
+ if !bytes.Equal(buf, buf2) {
+ t.Fatalf("%v: MarshalText = %q, AppendTo = %q", x, buf, buf2)
+ }
+}
+
+type netipType interface {
+ encoding.BinaryMarshaler
+ encoding.TextMarshaler
+ fmt.Stringer
+ IsValid() bool
+}
+
+type netipTypeCmp interface {
+ comparable
+ netipType
+}
+
+// checkStringParseRoundTrip checks that x's String method and the provided parse function can round trip correctly.
+func checkStringParseRoundTrip[P netipTypeCmp](t *testing.T, x P, parse func(string) (P, error)) {
+ if !x.IsValid() {
+ // Ignore invalid values.
+ return
+ }
+
+ s := x.String()
+ y, err := parse(s)
+ if err != nil {
+ t.Fatalf("s=%q err=%v", s, err)
+ }
+ if x != y {
+ t.Fatalf("%T round trip identity failure: s=%q x=%#v y=%#v", x, s, x, y)
+ }
+ s2 := y.String()
+ if s != s2 {
+ t.Fatalf("%T String round trip identity failure: s=%#v s2=%#v", x, s, s2)
+ }
+}
+
+func checkEncoding(t *testing.T, x netipType) {
+ if x.IsValid() {
+ checkTextMarshaler(t, x)
+ checkBinaryMarshaler(t, x)
+ checkTextMarshalMatchesString(t, x)
+ }
+
+ if am, ok := x.(appendMarshaler); ok {
+ checkTextMarshalMatchesAppendTo(t, am)
+ }
+}
diff --git a/libgo/go/net/netip/inlining_test.go b/libgo/go/net/netip/inlining_test.go
new file mode 100644
index 0000000..107fe1f
--- /dev/null
+++ b/libgo/go/net/netip/inlining_test.go
@@ -0,0 +1,110 @@
+// Copyright 2020 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package netip
+
+import (
+ "internal/testenv"
+ "os/exec"
+ "path/filepath"
+ "regexp"
+ "runtime"
+ "strings"
+ "testing"
+)
+
+func TestInlining(t *testing.T) {
+ testenv.MustHaveGoBuild(t)
+ t.Parallel()
+ var exe string
+ if runtime.GOOS == "windows" {
+ exe = ".exe"
+ }
+ out, err := exec.Command(
+ filepath.Join(runtime.GOROOT(), "bin", "go"+exe),
+ "build",
+ "--gcflags=-m",
+ "net/netip").CombinedOutput()
+ if err != nil {
+ t.Fatalf("go build: %v, %s", err, out)
+ }
+ got := map[string]bool{}
+ regexp.MustCompile(` can inline (\S+)`).ReplaceAllFunc(out, func(match []byte) []byte {
+ got[strings.TrimPrefix(string(match), " can inline ")] = true
+ return nil
+ })
+ wantInlinable := []string{
+ "(*uint128).halves",
+ "Addr.BitLen",
+ "Addr.hasZone",
+ "Addr.Is4",
+ "Addr.Is4In6",
+ "Addr.Is6",
+ "Addr.IsLoopback",
+ "Addr.IsMulticast",
+ "Addr.IsInterfaceLocalMulticast",
+ "Addr.IsValid",
+ "Addr.IsUnspecified",
+ "Addr.Less",
+ "Addr.lessOrEq",
+ "Addr.Unmap",
+ "Addr.Zone",
+ "Addr.v4",
+ "Addr.v6",
+ "Addr.v6u16",
+ "Addr.withoutZone",
+ "AddrPortFrom",
+ "AddrPort.Addr",
+ "AddrPort.Port",
+ "AddrPort.IsValid",
+ "Prefix.IsSingleIP",
+ "Prefix.Masked",
+ "Prefix.IsValid",
+ "PrefixFrom",
+ "Prefix.Addr",
+ "Prefix.Bits",
+ "AddrFrom4",
+ "IPv6LinkLocalAllNodes",
+ "IPv6Unspecified",
+ "MustParseAddr",
+ "MustParseAddrPort",
+ "MustParsePrefix",
+ "appendDecimal",
+ "appendHex",
+ "uint128.addOne",
+ "uint128.and",
+ "uint128.bitsClearedFrom",
+ "uint128.bitsSetFrom",
+ "uint128.isZero",
+ "uint128.not",
+ "uint128.or",
+ "uint128.subOne",
+ "uint128.xor",
+ }
+ switch runtime.GOARCH {
+ case "amd64", "arm64":
+ // These don't inline on 32-bit.
+ wantInlinable = append(wantInlinable,
+ "u64CommonPrefixLen",
+ "uint128.commonPrefixLen",
+ "Addr.Next",
+ "Addr.Prev",
+ )
+ }
+
+ for _, want := range wantInlinable {
+ if !got[want] {
+ t.Errorf("%q is no longer inlinable", want)
+ continue
+ }
+ delete(got, want)
+ }
+ for sym := range got {
+ if strings.Contains(sym, ".func") {
+ continue
+ }
+ t.Logf("not in expected set, but also inlinable: %q", sym)
+
+ }
+}
diff --git a/libgo/go/net/netip/leaf_alts.go b/libgo/go/net/netip/leaf_alts.go
new file mode 100644
index 0000000..70513ab
--- /dev/null
+++ b/libgo/go/net/netip/leaf_alts.go
@@ -0,0 +1,54 @@
+// Copyright 2021 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Stuff that exists in std, but we can't use due to being a dependency
+// of net, for go/build deps_test policy reasons.
+
+package netip
+
+func stringsLastIndexByte(s string, b byte) int {
+ for i := len(s) - 1; i >= 0; i-- {
+ if s[i] == b {
+ return i
+ }
+ }
+ return -1
+}
+
+func beUint64(b []byte) uint64 {
+ _ = b[7] // bounds check hint to compiler; see golang.org/issue/14808
+ return uint64(b[7]) | uint64(b[6])<<8 | uint64(b[5])<<16 | uint64(b[4])<<24 |
+ uint64(b[3])<<32 | uint64(b[2])<<40 | uint64(b[1])<<48 | uint64(b[0])<<56
+}
+
+func bePutUint64(b []byte, v uint64) {
+ _ = b[7] // early bounds check to guarantee safety of writes below
+ b[0] = byte(v >> 56)
+ b[1] = byte(v >> 48)
+ b[2] = byte(v >> 40)
+ b[3] = byte(v >> 32)
+ b[4] = byte(v >> 24)
+ b[5] = byte(v >> 16)
+ b[6] = byte(v >> 8)
+ b[7] = byte(v)
+}
+
+func bePutUint32(b []byte, v uint32) {
+ _ = b[3] // early bounds check to guarantee safety of writes below
+ b[0] = byte(v >> 24)
+ b[1] = byte(v >> 16)
+ b[2] = byte(v >> 8)
+ b[3] = byte(v)
+}
+
+func leUint16(b []byte) uint16 {
+ _ = b[1] // bounds check hint to compiler; see golang.org/issue/14808
+ return uint16(b[0]) | uint16(b[1])<<8
+}
+
+func lePutUint16(b []byte, v uint16) {
+ _ = b[1] // early bounds check to guarantee safety of writes below
+ b[0] = byte(v)
+ b[1] = byte(v >> 8)
+}
diff --git a/libgo/go/net/netip/netip.go b/libgo/go/net/netip/netip.go
new file mode 100644
index 0000000..591d38a
--- /dev/null
+++ b/libgo/go/net/netip/netip.go
@@ -0,0 +1,1498 @@
+// Copyright 2020 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package netip defines an IP address type that's a small value type.
+// Building on that Addr type, the package also defines AddrPort (an
+// IP address and a port), and Prefix (an IP address and a bit length
+// prefix).
+//
+// Compared to the net.IP type, this package's Addr type takes less
+// memory, is immutable, and is comparable (supports == and being a
+// map key).
+package netip
+
+import (
+ "errors"
+ "math"
+ "strconv"
+
+ "internal/bytealg"
+ "internal/intern"
+ "internal/itoa"
+)
+
+// Sizes: (64-bit)
+// net.IP: 24 byte slice header + {4, 16} = 28 to 40 bytes
+// net.IPAddr: 40 byte slice header + {4, 16} = 44 to 56 bytes + zone length
+// netip.Addr: 24 bytes (zone is per-name singleton, shared across all users)
+
+// Addr represents an IPv4 or IPv6 address (with or without a scoped
+// addressing zone), similar to net.IP or net.IPAddr.
+//
+// Unlike net.IP or net.IPAddr, Addr is a comparable value
+// type (it supports == and can be a map key) and is immutable.
+//
+// The zero Addr is not a valid IP address.
+// Addr{} is distinct from both 0.0.0.0 and ::.
+type Addr struct {
+ // addr is the hi and lo bits of an IPv6 address. If z==z4,
+ // hi and lo contain the IPv4-mapped IPv6 address.
+ //
+ // hi and lo are constructed by interpreting a 16-byte IPv6
+ // address as a big-endian 128-bit number. The most significant
+ // bits of that number go into hi, the rest into lo.
+ //
+ // For example, 0011:2233:4455:6677:8899:aabb:ccdd:eeff is stored as:
+ // addr.hi = 0x0011223344556677
+ // addr.lo = 0x8899aabbccddeeff
+ //
+ // We store IPs like this, rather than as [16]byte, because it
+ // turns most operations on IPs into arithmetic and bit-twiddling
+ // operations on 64-bit registers, which is much faster than
+ // bytewise processing.
+ addr uint128
+
+ // z is a combination of the address family and the IPv6 zone.
+ //
+ // nil means invalid IP address (for a zero Addr).
+ // z4 means an IPv4 address.
+ // z6noz means an IPv6 address without a zone.
+ //
+ // Otherwise it's the interned zone name string.
+ z *intern.Value
+}
+
+// z0, z4, and z6noz are sentinel IP.z values.
+// See the IP type's field docs.
+var (
+ z0 = (*intern.Value)(nil)
+ z4 = new(intern.Value)
+ z6noz = new(intern.Value)
+)
+
+// IPv6LinkLocalAllNodes returns the IPv6 link-local all nodes multicast
+// address ff02::1.
+func IPv6LinkLocalAllNodes() Addr { return AddrFrom16([16]byte{0: 0xff, 1: 0x02, 15: 0x01}) }
+
+// IPv6Unspecified returns the IPv6 unspecified address "::".
+func IPv6Unspecified() Addr { return Addr{z: z6noz} }
+
+// IPv4Unspecified returns the IPv4 unspecified address "0.0.0.0".
+func IPv4Unspecified() Addr { return AddrFrom4([4]byte{}) }
+
+// AddrFrom4 returns the address of the IPv4 address given by the bytes in addr.
+func AddrFrom4(addr [4]byte) Addr {
+ return Addr{
+ addr: uint128{0, 0xffff00000000 | uint64(addr[0])<<24 | uint64(addr[1])<<16 | uint64(addr[2])<<8 | uint64(addr[3])},
+ z: z4,
+ }
+}
+
+// AddrFrom16 returns the IPv6 address given by the bytes in addr.
+// An IPv6-mapped IPv4 address is left as an IPv6 address.
+// (Use Unmap to convert them if needed.)
+func AddrFrom16(addr [16]byte) Addr {
+ return Addr{
+ addr: uint128{
+ beUint64(addr[:8]),
+ beUint64(addr[8:]),
+ },
+ z: z6noz,
+ }
+}
+
+// ipv6Slice is like IPv6Raw, but operates on a 16-byte slice. Assumes
+// slice is 16 bytes, caller must enforce this.
+func ipv6Slice(addr []byte) Addr {
+ return Addr{
+ addr: uint128{
+ beUint64(addr[:8]),
+ beUint64(addr[8:]),
+ },
+ z: z6noz,
+ }
+}
+
+// ParseAddr parses s as an IP address, returning the result. The string
+// s can be in dotted decimal ("192.0.2.1"), IPv6 ("2001:db8::68"),
+// or IPv6 with a scoped addressing zone ("fe80::1cc0:3e8c:119f:c2e1%ens18").
+func ParseAddr(s string) (Addr, error) {
+ for i := 0; i < len(s); i++ {
+ switch s[i] {
+ case '.':
+ return parseIPv4(s)
+ case ':':
+ return parseIPv6(s)
+ case '%':
+ // Assume that this was trying to be an IPv6 address with
+ // a zone specifier, but the address is missing.
+ return Addr{}, parseAddrError{in: s, msg: "missing IPv6 address"}
+ }
+ }
+ return Addr{}, parseAddrError{in: s, msg: "unable to parse IP"}
+}
+
+// MustParseAddr calls ParseAddr(s) and panics on error.
+// It is intended for use in tests with hard-coded strings.
+func MustParseAddr(s string) Addr {
+ ip, err := ParseAddr(s)
+ if err != nil {
+ panic(err)
+ }
+ return ip
+}
+
+type parseAddrError struct {
+ in string // the string given to ParseAddr
+ msg string // an explanation of the parse failure
+ at string // optionally, the unparsed portion of in at which the error occurred.
+}
+
+func (err parseAddrError) Error() string {
+ q := strconv.Quote
+ if err.at != "" {
+ return "ParseAddr(" + q(err.in) + "): " + err.msg + " (at " + q(err.at) + ")"
+ }
+ return "ParseAddr(" + q(err.in) + "): " + err.msg
+}
+
+// parseIPv4 parses s as an IPv4 address (in form "192.168.0.1").
+func parseIPv4(s string) (ip Addr, err error) {
+ var fields [4]uint8
+ var val, pos int
+ var digLen int // number of digits in current octet
+ for i := 0; i < len(s); i++ {
+ if s[i] >= '0' && s[i] <= '9' {
+ if digLen == 1 && val == 0 {
+ return Addr{}, parseAddrError{in: s, msg: "IPv4 field has octet with leading zero"}
+ }
+ val = val*10 + int(s[i]) - '0'
+ digLen++
+ if val > 255 {
+ return Addr{}, parseAddrError{in: s, msg: "IPv4 field has value >255"}
+ }
+ } else if s[i] == '.' {
+ // .1.2.3
+ // 1.2.3.
+ // 1..2.3
+ if i == 0 || i == len(s)-1 || s[i-1] == '.' {
+ return Addr{}, parseAddrError{in: s, msg: "IPv4 field must have at least one digit", at: s[i:]}
+ }
+ // 1.2.3.4.5
+ if pos == 3 {
+ return Addr{}, parseAddrError{in: s, msg: "IPv4 address too long"}
+ }
+ fields[pos] = uint8(val)
+ pos++
+ val = 0
+ digLen = 0
+ } else {
+ return Addr{}, parseAddrError{in: s, msg: "unexpected character", at: s[i:]}
+ }
+ }
+ if pos < 3 {
+ return Addr{}, parseAddrError{in: s, msg: "IPv4 address too short"}
+ }
+ fields[3] = uint8(val)
+ return AddrFrom4(fields), nil
+}
+
+// parseIPv6 parses s as an IPv6 address (in form "2001:db8::68").
+func parseIPv6(in string) (Addr, error) {
+ s := in
+
+ // Split off the zone right from the start. Yes it's a second scan
+ // of the string, but trying to handle it inline makes a bunch of
+ // other inner loop conditionals more expensive, and it ends up
+ // being slower.
+ zone := ""
+ i := bytealg.IndexByteString(s, '%')
+ if i != -1 {
+ s, zone = s[:i], s[i+1:]
+ if zone == "" {
+ // Not allowed to have an empty zone if explicitly specified.
+ return Addr{}, parseAddrError{in: in, msg: "zone must be a non-empty string"}
+ }
+ }
+
+ var ip [16]byte
+ ellipsis := -1 // position of ellipsis in ip
+
+ // Might have leading ellipsis
+ if len(s) >= 2 && s[0] == ':' && s[1] == ':' {
+ ellipsis = 0
+ s = s[2:]
+ // Might be only ellipsis
+ if len(s) == 0 {
+ return IPv6Unspecified().WithZone(zone), nil
+ }
+ }
+
+ // Loop, parsing hex numbers followed by colon.
+ i = 0
+ for i < 16 {
+ // Hex number. Similar to parseIPv4, inlining the hex number
+ // parsing yields a significant performance increase.
+ off := 0
+ acc := uint32(0)
+ for ; off < len(s); off++ {
+ c := s[off]
+ if c >= '0' && c <= '9' {
+ acc = (acc << 4) + uint32(c-'0')
+ } else if c >= 'a' && c <= 'f' {
+ acc = (acc << 4) + uint32(c-'a'+10)
+ } else if c >= 'A' && c <= 'F' {
+ acc = (acc << 4) + uint32(c-'A'+10)
+ } else {
+ break
+ }
+ if acc > math.MaxUint16 {
+ // Overflow, fail.
+ return Addr{}, parseAddrError{in: in, msg: "IPv6 field has value >=2^16", at: s}
+ }
+ }
+ if off == 0 {
+ // No digits found, fail.
+ return Addr{}, parseAddrError{in: in, msg: "each colon-separated field must have at least one digit", at: s}
+ }
+
+ // If followed by dot, might be in trailing IPv4.
+ if off < len(s) && s[off] == '.' {
+ if ellipsis < 0 && i != 12 {
+ // Not the right place.
+ return Addr{}, parseAddrError{in: in, msg: "embedded IPv4 address must replace the final 2 fields of the address", at: s}
+ }
+ if i+4 > 16 {
+ // Not enough room.
+ return Addr{}, parseAddrError{in: in, msg: "too many hex fields to fit an embedded IPv4 at the end of the address", at: s}
+ }
+ // TODO: could make this a bit faster by having a helper
+ // that parses to a [4]byte, and have both parseIPv4 and
+ // parseIPv6 use it.
+ ip4, err := parseIPv4(s)
+ if err != nil {
+ return Addr{}, parseAddrError{in: in, msg: err.Error(), at: s}
+ }
+ ip[i] = ip4.v4(0)
+ ip[i+1] = ip4.v4(1)
+ ip[i+2] = ip4.v4(2)
+ ip[i+3] = ip4.v4(3)
+ s = ""
+ i += 4
+ break
+ }
+
+ // Save this 16-bit chunk.
+ ip[i] = byte(acc >> 8)
+ ip[i+1] = byte(acc)
+ i += 2
+
+ // Stop at end of string.
+ s = s[off:]
+ if len(s) == 0 {
+ break
+ }
+
+ // Otherwise must be followed by colon and more.
+ if s[0] != ':' {
+ return Addr{}, parseAddrError{in: in, msg: "unexpected character, want colon", at: s}
+ } else if len(s) == 1 {
+ return Addr{}, parseAddrError{in: in, msg: "colon must be followed by more characters", at: s}
+ }
+ s = s[1:]
+
+ // Look for ellipsis.
+ if s[0] == ':' {
+ if ellipsis >= 0 { // already have one
+ return Addr{}, parseAddrError{in: in, msg: "multiple :: in address", at: s}
+ }
+ ellipsis = i
+ s = s[1:]
+ if len(s) == 0 { // can be at end
+ break
+ }
+ }
+ }
+
+ // Must have used entire string.
+ if len(s) != 0 {
+ return Addr{}, parseAddrError{in: in, msg: "trailing garbage after address", at: s}
+ }
+
+ // If didn't parse enough, expand ellipsis.
+ if i < 16 {
+ if ellipsis < 0 {
+ return Addr{}, parseAddrError{in: in, msg: "address string too short"}
+ }
+ n := 16 - i
+ for j := i - 1; j >= ellipsis; j-- {
+ ip[j+n] = ip[j]
+ }
+ for j := ellipsis + n - 1; j >= ellipsis; j-- {
+ ip[j] = 0
+ }
+ } else if ellipsis >= 0 {
+ // Ellipsis must represent at least one 0 group.
+ return Addr{}, parseAddrError{in: in, msg: "the :: must expand to at least one field of zeros"}
+ }
+ return AddrFrom16(ip).WithZone(zone), nil
+}
+
+// AddrFromSlice parses the 4- or 16-byte byte slice as an IPv4 or IPv6 address.
+// Note that a net.IP can be passed directly as the []byte argument.
+// If slice's length is not 4 or 16, AddrFromSlice returns Addr{}, false.
+func AddrFromSlice(slice []byte) (ip Addr, ok bool) {
+ switch len(slice) {
+ case 4:
+ return AddrFrom4(*(*[4]byte)(slice)), true
+ case 16:
+ return ipv6Slice(slice), true
+ }
+ return Addr{}, false
+}
+
+// v4 returns the i'th byte of ip. If ip is not an IPv4, v4 returns
+// unspecified garbage.
+func (ip Addr) v4(i uint8) uint8 {
+ return uint8(ip.addr.lo >> ((3 - i) * 8))
+}
+
+// v6 returns the i'th byte of ip. If ip is an IPv4 address, this
+// accesses the IPv4-mapped IPv6 address form of the IP.
+func (ip Addr) v6(i uint8) uint8 {
+ return uint8(*(ip.addr.halves()[(i/8)%2]) >> ((7 - i%8) * 8))
+}
+
+// v6u16 returns the i'th 16-bit word of ip. If ip is an IPv4 address,
+// this accesses the IPv4-mapped IPv6 address form of the IP.
+func (ip Addr) v6u16(i uint8) uint16 {
+ return uint16(*(ip.addr.halves()[(i/4)%2]) >> ((3 - i%4) * 16))
+}
+
+// isZero reports whether ip is the zero value of the IP type.
+// The zero value is not a valid IP address of any type.
+//
+// Note that "0.0.0.0" and "::" are not the zero value. Use IsUnspecified to
+// check for these values instead.
+func (ip Addr) isZero() bool {
+ // Faster than comparing ip == Addr{}, but effectively equivalent,
+ // as there's no way to make an IP with a nil z from this package.
+ return ip.z == z0
+}
+
+// IsValid reports whether the Addr is an initialized address (not the zero Addr).
+//
+// Note that "0.0.0.0" and "::" are both valid values.
+func (ip Addr) IsValid() bool { return ip.z != z0 }
+
+// BitLen returns the number of bits in the IP address:
+// 128 for IPv6, 32 for IPv4, and 0 for the zero Addr.
+//
+// Note that IPv4-mapped IPv6 addresses are considered IPv6 addresses
+// and therefore have bit length 128.
+func (ip Addr) BitLen() int {
+ switch ip.z {
+ case z0:
+ return 0
+ case z4:
+ return 32
+ }
+ return 128
+}
+
+// Zone returns ip's IPv6 scoped addressing zone, if any.
+func (ip Addr) Zone() string {
+ if ip.z == nil {
+ return ""
+ }
+ zone, _ := ip.z.Get().(string)
+ return zone
+}
+
+// Compare returns an integer comparing two IPs.
+// The result will be 0 if ip == ip2, -1 if ip < ip2, and +1 if ip > ip2.
+// The definition of "less than" is the same as the Less method.
+func (ip Addr) Compare(ip2 Addr) int {
+ f1, f2 := ip.BitLen(), ip2.BitLen()
+ if f1 < f2 {
+ return -1
+ }
+ if f1 > f2 {
+ return 1
+ }
+ hi1, hi2 := ip.addr.hi, ip2.addr.hi
+ if hi1 < hi2 {
+ return -1
+ }
+ if hi1 > hi2 {
+ return 1
+ }
+ lo1, lo2 := ip.addr.lo, ip2.addr.lo
+ if lo1 < lo2 {
+ return -1
+ }
+ if lo1 > lo2 {
+ return 1
+ }
+ if ip.Is6() {
+ za, zb := ip.Zone(), ip2.Zone()
+ if za < zb {
+ return -1
+ }
+ if za > zb {
+ return 1
+ }
+ }
+ return 0
+}
+
+// Less reports whether ip sorts before ip2.
+// IP addresses sort first by length, then their address.
+// IPv6 addresses with zones sort just after the same address without a zone.
+func (ip Addr) Less(ip2 Addr) bool { return ip.Compare(ip2) == -1 }
+
+func (ip Addr) lessOrEq(ip2 Addr) bool { return ip.Compare(ip2) <= 0 }
+
+// Is4 reports whether ip is an IPv4 address.
+//
+// It returns false for IP4-mapped IPv6 addresses. See IP.Unmap.
+func (ip Addr) Is4() bool {
+ return ip.z == z4
+}
+
+// Is4In6 reports whether ip is an IPv4-mapped IPv6 address.
+func (ip Addr) Is4In6() bool {
+ return ip.Is6() && ip.addr.hi == 0 && ip.addr.lo>>32 == 0xffff
+}
+
+// Is6 reports whether ip is an IPv6 address, including IPv4-mapped
+// IPv6 addresses.
+func (ip Addr) Is6() bool {
+ return ip.z != z0 && ip.z != z4
+}
+
+// Unmap returns ip with any IPv4-mapped IPv6 address prefix removed.
+//
+// That is, if ip is an IPv6 address wrapping an IPv4 adddress, it
+// returns the wrapped IPv4 address. Otherwise it returns ip unmodified.
+func (ip Addr) Unmap() Addr {
+ if ip.Is4In6() {
+ ip.z = z4
+ }
+ return ip
+}
+
+// WithZone returns an IP that's the same as ip but with the provided
+// zone. If zone is empty, the zone is removed. If ip is an IPv4
+// address, WithZone is a no-op and returns ip unchanged.
+func (ip Addr) WithZone(zone string) Addr {
+ if !ip.Is6() {
+ return ip
+ }
+ if zone == "" {
+ ip.z = z6noz
+ return ip
+ }
+ ip.z = intern.GetByString(zone)
+ return ip
+}
+
+// withoutZone unconditionally strips the zone from IP.
+// It's similar to WithZone, but small enough to be inlinable.
+func (ip Addr) withoutZone() Addr {
+ if !ip.Is6() {
+ return ip
+ }
+ ip.z = z6noz
+ return ip
+}
+
+// hasZone reports whether IP has an IPv6 zone.
+func (ip Addr) hasZone() bool {
+ return ip.z != z0 && ip.z != z4 && ip.z != z6noz
+}
+
+// IsLinkLocalUnicast reports whether ip is a link-local unicast address.
+func (ip Addr) IsLinkLocalUnicast() bool {
+ // Dynamic Configuration of IPv4 Link-Local Addresses
+ // https://datatracker.ietf.org/doc/html/rfc3927#section-2.1
+ if ip.Is4() {
+ return ip.v4(0) == 169 && ip.v4(1) == 254
+ }
+ // IP Version 6 Addressing Architecture (2.4 Address Type Identification)
+ // https://datatracker.ietf.org/doc/html/rfc4291#section-2.4
+ if ip.Is6() {
+ return ip.v6u16(0)&0xffc0 == 0xfe80
+ }
+ return false // zero value
+}
+
+// IsLoopback reports whether ip is a loopback address.
+func (ip Addr) IsLoopback() bool {
+ // Requirements for Internet Hosts -- Communication Layers (3.2.1.3 Addressing)
+ // https://datatracker.ietf.org/doc/html/rfc1122#section-3.2.1.3
+ if ip.Is4() {
+ return ip.v4(0) == 127
+ }
+ // IP Version 6 Addressing Architecture (2.4 Address Type Identification)
+ // https://datatracker.ietf.org/doc/html/rfc4291#section-2.4
+ if ip.Is6() {
+ return ip.addr.hi == 0 && ip.addr.lo == 1
+ }
+ return false // zero value
+}
+
+// IsMulticast reports whether ip is a multicast address.
+func (ip Addr) IsMulticast() bool {
+ // Host Extensions for IP Multicasting (4. HOST GROUP ADDRESSES)
+ // https://datatracker.ietf.org/doc/html/rfc1112#section-4
+ if ip.Is4() {
+ return ip.v4(0)&0xf0 == 0xe0
+ }
+ // IP Version 6 Addressing Architecture (2.4 Address Type Identification)
+ // https://datatracker.ietf.org/doc/html/rfc4291#section-2.4
+ if ip.Is6() {
+ return ip.addr.hi>>(64-8) == 0xff // ip.v6(0) == 0xff
+ }
+ return false // zero value
+}
+
+// IsInterfaceLocalMulticast reports whether ip is an IPv6 interface-local
+// multicast address.
+func (ip Addr) IsInterfaceLocalMulticast() bool {
+ // IPv6 Addressing Architecture (2.7.1. Pre-Defined Multicast Addresses)
+ // https://datatracker.ietf.org/doc/html/rfc4291#section-2.7.1
+ if ip.Is6() {
+ return ip.v6u16(0)&0xff0f == 0xff01
+ }
+ return false // zero value
+}
+
+// IsLinkLocalMulticast reports whether ip is a link-local multicast address.
+func (ip Addr) IsLinkLocalMulticast() bool {
+ // IPv4 Multicast Guidelines (4. Local Network Control Block (224.0.0/24))
+ // https://datatracker.ietf.org/doc/html/rfc5771#section-4
+ if ip.Is4() {
+ return ip.v4(0) == 224 && ip.v4(1) == 0 && ip.v4(2) == 0
+ }
+ // IPv6 Addressing Architecture (2.7.1. Pre-Defined Multicast Addresses)
+ // https://datatracker.ietf.org/doc/html/rfc4291#section-2.7.1
+ if ip.Is6() {
+ return ip.v6u16(0)&0xff0f == 0xff02
+ }
+ return false // zero value
+}
+
+// IsGlobalUnicast reports whether ip is a global unicast address.
+//
+// It returns true for IPv6 addresses which fall outside of the current
+// IANA-allocated 2000::/3 global unicast space, with the exception of the
+// link-local address space. It also returns true even if ip is in the IPv4
+// private address space or IPv6 unique local address space.
+// It returns false for the zero Addr.
+//
+// For reference, see RFC 1122, RFC 4291, and RFC 4632.
+func (ip Addr) IsGlobalUnicast() bool {
+ if ip.z == z0 {
+ // Invalid or zero-value.
+ return false
+ }
+
+ // Match package net's IsGlobalUnicast logic. Notably private IPv4 addresses
+ // and ULA IPv6 addresses are still considered "global unicast".
+ if ip.Is4() && (ip == IPv4Unspecified() || ip == AddrFrom4([4]byte{255, 255, 255, 255})) {
+ return false
+ }
+
+ return ip != IPv6Unspecified() &&
+ !ip.IsLoopback() &&
+ !ip.IsMulticast() &&
+ !ip.IsLinkLocalUnicast()
+}
+
+// IsPrivate reports whether ip is a private address, according to RFC 1918
+// (IPv4 addresses) and RFC 4193 (IPv6 addresses). That is, it reports whether
+// ip is in 10.0.0.0/8, 172.16.0.0/12, 192.168.0.0/16, or fc00::/7. This is the
+// same as net.IP.IsPrivate.
+func (ip Addr) IsPrivate() bool {
+ // Match the stdlib's IsPrivate logic.
+ if ip.Is4() {
+ // RFC 1918 allocates 10.0.0.0/8, 172.16.0.0/12, and 192.168.0.0/16 as
+ // private IPv4 address subnets.
+ return ip.v4(0) == 10 ||
+ (ip.v4(0) == 172 && ip.v4(1)&0xf0 == 16) ||
+ (ip.v4(0) == 192 && ip.v4(1) == 168)
+ }
+
+ if ip.Is6() {
+ // RFC 4193 allocates fc00::/7 as the unique local unicast IPv6 address
+ // subnet.
+ return ip.v6(0)&0xfe == 0xfc
+ }
+
+ return false // zero value
+}
+
+// IsUnspecified reports whether ip is an unspecified address, either the IPv4
+// address "0.0.0.0" or the IPv6 address "::".
+//
+// Note that the zero Addr is not an unspecified address.
+func (ip Addr) IsUnspecified() bool {
+ return ip == IPv4Unspecified() || ip == IPv6Unspecified()
+}
+
+// Prefix keeps only the top b bits of IP, producing a Prefix
+// of the specified length.
+// If ip is a zero Addr, Prefix always returns a zero Prefix and a nil error.
+// Otherwise, if bits is less than zero or greater than ip.BitLen(),
+// Prefix returns an error.
+func (ip Addr) Prefix(b int) (Prefix, error) {
+ if b < 0 {
+ return Prefix{}, errors.New("negative Prefix bits")
+ }
+ effectiveBits := b
+ switch ip.z {
+ case z0:
+ return Prefix{}, nil
+ case z4:
+ if b > 32 {
+ return Prefix{}, errors.New("prefix length " + itoa.Itoa(b) + " too large for IPv4")
+ }
+ effectiveBits += 96
+ default:
+ if b > 128 {
+ return Prefix{}, errors.New("prefix length " + itoa.Itoa(b) + " too large for IPv6")
+ }
+ }
+ ip.addr = ip.addr.and(mask6(effectiveBits))
+ return PrefixFrom(ip, b), nil
+}
+
+const (
+ netIPv4len = 4
+ netIPv6len = 16
+)
+
+// As16 returns the IP address in its 16-byte representation.
+// IPv4 addresses are returned in their v6-mapped form.
+// IPv6 addresses with zones are returned without their zone (use the
+// Zone method to get it).
+// The ip zero value returns all zeroes.
+func (ip Addr) As16() (a16 [16]byte) {
+ bePutUint64(a16[:8], ip.addr.hi)
+ bePutUint64(a16[8:], ip.addr.lo)
+ return a16
+}
+
+// As4 returns an IPv4 or IPv4-in-IPv6 address in its 4-byte representation.
+// If ip is the zero Addr or an IPv6 address, As4 panics.
+// Note that 0.0.0.0 is not the zero Addr.
+func (ip Addr) As4() (a4 [4]byte) {
+ if ip.z == z4 || ip.Is4In6() {
+ bePutUint32(a4[:], uint32(ip.addr.lo))
+ return a4
+ }
+ if ip.z == z0 {
+ panic("As4 called on IP zero value")
+ }
+ panic("As4 called on IPv6 address")
+}
+
+// AsSlice returns an IPv4 or IPv6 address in its respective 4-byte or 16-byte representation.
+func (ip Addr) AsSlice() []byte {
+ switch ip.z {
+ case z0:
+ return nil
+ case z4:
+ var ret [4]byte
+ bePutUint32(ret[:], uint32(ip.addr.lo))
+ return ret[:]
+ default:
+ var ret [16]byte
+ bePutUint64(ret[:8], ip.addr.hi)
+ bePutUint64(ret[8:], ip.addr.lo)
+ return ret[:]
+ }
+}
+
+// Next returns the address following ip.
+// If there is none, it returns the zero Addr.
+func (ip Addr) Next() Addr {
+ ip.addr = ip.addr.addOne()
+ if ip.Is4() {
+ if uint32(ip.addr.lo) == 0 {
+ // Overflowed.
+ return Addr{}
+ }
+ } else {
+ if ip.addr.isZero() {
+ // Overflowed
+ return Addr{}
+ }
+ }
+ return ip
+}
+
+// Prev returns the IP before ip.
+// If there is none, it returns the IP zero value.
+func (ip Addr) Prev() Addr {
+ if ip.Is4() {
+ if uint32(ip.addr.lo) == 0 {
+ return Addr{}
+ }
+ } else if ip.addr.isZero() {
+ return Addr{}
+ }
+ ip.addr = ip.addr.subOne()
+ return ip
+}
+
+// String returns the string form of the IP address ip.
+// It returns one of 5 forms:
+//
+// - "invalid IP", if ip is the zero Addr
+// - IPv4 dotted decimal ("192.0.2.1")
+// - IPv6 ("2001:db8::1")
+// - "::ffff:1.2.3.4" (if Is4In6)
+// - IPv6 with zone ("fe80:db8::1%eth0")
+//
+// Note that unlike package net's IP.String method,
+// IP4-mapped IPv6 addresses format with a "::ffff:"
+// prefix before the dotted quad.
+func (ip Addr) String() string {
+ switch ip.z {
+ case z0:
+ return "invalid IP"
+ case z4:
+ return ip.string4()
+ default:
+ if ip.Is4In6() {
+ // TODO(bradfitz): this could alloc less.
+ if z := ip.Zone(); z != "" {
+ return "::ffff:" + ip.Unmap().String() + "%" + z
+ } else {
+ return "::ffff:" + ip.Unmap().String()
+ }
+ }
+ return ip.string6()
+ }
+}
+
+// AppendTo appends a text encoding of ip,
+// as generated by MarshalText,
+// to b and returns the extended buffer.
+func (ip Addr) AppendTo(b []byte) []byte {
+ switch ip.z {
+ case z0:
+ return b
+ case z4:
+ return ip.appendTo4(b)
+ default:
+ if ip.Is4In6() {
+ b = append(b, "::ffff:"...)
+ b = ip.Unmap().appendTo4(b)
+ if z := ip.Zone(); z != "" {
+ b = append(b, '%')
+ b = append(b, z...)
+ }
+ return b
+ }
+ return ip.appendTo6(b)
+ }
+}
+
+// digits is a string of the hex digits from 0 to f. It's used in
+// appendDecimal and appendHex to format IP addresses.
+const digits = "0123456789abcdef"
+
+// appendDecimal appends the decimal string representation of x to b.
+func appendDecimal(b []byte, x uint8) []byte {
+ // Using this function rather than strconv.AppendUint makes IPv4
+ // string building 2x faster.
+
+ if x >= 100 {
+ b = append(b, digits[x/100])
+ }
+ if x >= 10 {
+ b = append(b, digits[x/10%10])
+ }
+ return append(b, digits[x%10])
+}
+
+// appendHex appends the hex string representation of x to b.
+func appendHex(b []byte, x uint16) []byte {
+ // Using this function rather than strconv.AppendUint makes IPv6
+ // string building 2x faster.
+
+ if x >= 0x1000 {
+ b = append(b, digits[x>>12])
+ }
+ if x >= 0x100 {
+ b = append(b, digits[x>>8&0xf])
+ }
+ if x >= 0x10 {
+ b = append(b, digits[x>>4&0xf])
+ }
+ return append(b, digits[x&0xf])
+}
+
+// appendHexPad appends the fully padded hex string representation of x to b.
+func appendHexPad(b []byte, x uint16) []byte {
+ return append(b, digits[x>>12], digits[x>>8&0xf], digits[x>>4&0xf], digits[x&0xf])
+}
+
+func (ip Addr) string4() string {
+ const max = len("255.255.255.255")
+ ret := make([]byte, 0, max)
+ ret = ip.appendTo4(ret)
+ return string(ret)
+}
+
+func (ip Addr) appendTo4(ret []byte) []byte {
+ ret = appendDecimal(ret, ip.v4(0))
+ ret = append(ret, '.')
+ ret = appendDecimal(ret, ip.v4(1))
+ ret = append(ret, '.')
+ ret = appendDecimal(ret, ip.v4(2))
+ ret = append(ret, '.')
+ ret = appendDecimal(ret, ip.v4(3))
+ return ret
+}
+
+// string6 formats ip in IPv6 textual representation. It follows the
+// guidelines in section 4 of RFC 5952
+// (https://tools.ietf.org/html/rfc5952#section-4): no unnecessary
+// zeros, use :: to elide the longest run of zeros, and don't use ::
+// to compact a single zero field.
+func (ip Addr) string6() string {
+ // Use a zone with a "plausibly long" name, so that most zone-ful
+ // IP addresses won't require additional allocation.
+ //
+ // The compiler does a cool optimization here, where ret ends up
+ // stack-allocated and so the only allocation this function does
+ // is to construct the returned string. As such, it's okay to be a
+ // bit greedy here, size-wise.
+ const max = len("ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff%enp5s0")
+ ret := make([]byte, 0, max)
+ ret = ip.appendTo6(ret)
+ return string(ret)
+}
+
+func (ip Addr) appendTo6(ret []byte) []byte {
+ zeroStart, zeroEnd := uint8(255), uint8(255)
+ for i := uint8(0); i < 8; i++ {
+ j := i
+ for j < 8 && ip.v6u16(j) == 0 {
+ j++
+ }
+ if l := j - i; l >= 2 && l > zeroEnd-zeroStart {
+ zeroStart, zeroEnd = i, j
+ }
+ }
+
+ for i := uint8(0); i < 8; i++ {
+ if i == zeroStart {
+ ret = append(ret, ':', ':')
+ i = zeroEnd
+ if i >= 8 {
+ break
+ }
+ } else if i > 0 {
+ ret = append(ret, ':')
+ }
+
+ ret = appendHex(ret, ip.v6u16(i))
+ }
+
+ if ip.z != z6noz {
+ ret = append(ret, '%')
+ ret = append(ret, ip.Zone()...)
+ }
+ return ret
+}
+
+// StringExpanded is like String but IPv6 addresses are expanded with leading
+// zeroes and no "::" compression. For example, "2001:db8::1" becomes
+// "2001:0db8:0000:0000:0000:0000:0000:0001".
+func (ip Addr) StringExpanded() string {
+ switch ip.z {
+ case z0, z4:
+ return ip.String()
+ }
+
+ const size = len("ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff")
+ ret := make([]byte, 0, size)
+ for i := uint8(0); i < 8; i++ {
+ if i > 0 {
+ ret = append(ret, ':')
+ }
+
+ ret = appendHexPad(ret, ip.v6u16(i))
+ }
+
+ if ip.z != z6noz {
+ // The addition of a zone will cause a second allocation, but when there
+ // is no zone the ret slice will be stack allocated.
+ ret = append(ret, '%')
+ ret = append(ret, ip.Zone()...)
+ }
+ return string(ret)
+}
+
+// MarshalText implements the encoding.TextMarshaler interface,
+// The encoding is the same as returned by String, with one exception:
+// If ip is the zero Addr, the encoding is the empty string.
+func (ip Addr) MarshalText() ([]byte, error) {
+ switch ip.z {
+ case z0:
+ return []byte(""), nil
+ case z4:
+ max := len("255.255.255.255")
+ b := make([]byte, 0, max)
+ return ip.appendTo4(b), nil
+ default:
+ max := len("ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff%enp5s0")
+ b := make([]byte, 0, max)
+ if ip.Is4In6() {
+ b = append(b, "::ffff:"...)
+ b = ip.Unmap().appendTo4(b)
+ if z := ip.Zone(); z != "" {
+ b = append(b, '%')
+ b = append(b, z...)
+ }
+ return b, nil
+ }
+ return ip.appendTo6(b), nil
+ }
+
+}
+
+// UnmarshalText implements the encoding.TextUnmarshaler interface.
+// The IP address is expected in a form accepted by ParseAddr.
+//
+// If text is empty, UnmarshalText sets *ip to the zero Addr and
+// returns no error.
+func (ip *Addr) UnmarshalText(text []byte) error {
+ if len(text) == 0 {
+ *ip = Addr{}
+ return nil
+ }
+ var err error
+ *ip, err = ParseAddr(string(text))
+ return err
+}
+
+func (ip Addr) marshalBinaryWithTrailingBytes(trailingBytes int) []byte {
+ var b []byte
+ switch ip.z {
+ case z0:
+ b = make([]byte, trailingBytes)
+ case z4:
+ b = make([]byte, 4+trailingBytes)
+ bePutUint32(b, uint32(ip.addr.lo))
+ default:
+ z := ip.Zone()
+ b = make([]byte, 16+len(z)+trailingBytes)
+ bePutUint64(b[:8], ip.addr.hi)
+ bePutUint64(b[8:], ip.addr.lo)
+ copy(b[16:], z)
+ }
+ return b
+}
+
+// MarshalBinary implements the encoding.BinaryMarshaler interface.
+// It returns a zero-length slice for the zero Addr,
+// the 4-byte form for an IPv4 address,
+// and the 16-byte form with zone appended for an IPv6 address.
+func (ip Addr) MarshalBinary() ([]byte, error) {
+ return ip.marshalBinaryWithTrailingBytes(0), nil
+}
+
+// UnmarshalBinary implements the encoding.BinaryUnmarshaler interface.
+// It expects data in the form generated by MarshalBinary.
+func (ip *Addr) UnmarshalBinary(b []byte) error {
+ n := len(b)
+ switch {
+ case n == 0:
+ *ip = Addr{}
+ return nil
+ case n == 4:
+ *ip = AddrFrom4(*(*[4]byte)(b))
+ return nil
+ case n == 16:
+ *ip = ipv6Slice(b)
+ return nil
+ case n > 16:
+ *ip = ipv6Slice(b[:16]).WithZone(string(b[16:]))
+ return nil
+ }
+ return errors.New("unexpected slice size")
+}
+
+// AddrPort is an IP and a port number.
+type AddrPort struct {
+ ip Addr
+ port uint16
+}
+
+// AddrPortFrom returns an AddrPort with the provided IP and port.
+// It does not allocate.
+func AddrPortFrom(ip Addr, port uint16) AddrPort { return AddrPort{ip: ip, port: port} }
+
+// Addr returns p's IP address.
+func (p AddrPort) Addr() Addr { return p.ip }
+
+// Port returns p's port.
+func (p AddrPort) Port() uint16 { return p.port }
+
+// splitAddrPort splits s into an IP address string and a port
+// string. It splits strings shaped like "foo:bar" or "[foo]:bar",
+// without further validating the substrings. v6 indicates whether the
+// ip string should parse as an IPv6 address or an IPv4 address, in
+// order for s to be a valid ip:port string.
+func splitAddrPort(s string) (ip, port string, v6 bool, err error) {
+ i := stringsLastIndexByte(s, ':')
+ if i == -1 {
+ return "", "", false, errors.New("not an ip:port")
+ }
+
+ ip, port = s[:i], s[i+1:]
+ if len(ip) == 0 {
+ return "", "", false, errors.New("no IP")
+ }
+ if len(port) == 0 {
+ return "", "", false, errors.New("no port")
+ }
+ if ip[0] == '[' {
+ if len(ip) < 2 || ip[len(ip)-1] != ']' {
+ return "", "", false, errors.New("missing ]")
+ }
+ ip = ip[1 : len(ip)-1]
+ v6 = true
+ }
+
+ return ip, port, v6, nil
+}
+
+// ParseAddrPort parses s as an AddrPort.
+//
+// It doesn't do any name resolution: both the address and the port
+// must be numeric.
+func ParseAddrPort(s string) (AddrPort, error) {
+ var ipp AddrPort
+ ip, port, v6, err := splitAddrPort(s)
+ if err != nil {
+ return ipp, err
+ }
+ port16, err := strconv.ParseUint(port, 10, 16)
+ if err != nil {
+ return ipp, errors.New("invalid port " + strconv.Quote(port) + " parsing " + strconv.Quote(s))
+ }
+ ipp.port = uint16(port16)
+ ipp.ip, err = ParseAddr(ip)
+ if err != nil {
+ return AddrPort{}, err
+ }
+ if v6 && ipp.ip.Is4() {
+ return AddrPort{}, errors.New("invalid ip:port " + strconv.Quote(s) + ", square brackets can only be used with IPv6 addresses")
+ } else if !v6 && ipp.ip.Is6() {
+ return AddrPort{}, errors.New("invalid ip:port " + strconv.Quote(s) + ", IPv6 addresses must be surrounded by square brackets")
+ }
+ return ipp, nil
+}
+
+// MustParseAddrPort calls ParseAddrPort(s) and panics on error.
+// It is intended for use in tests with hard-coded strings.
+func MustParseAddrPort(s string) AddrPort {
+ ip, err := ParseAddrPort(s)
+ if err != nil {
+ panic(err)
+ }
+ return ip
+}
+
+// isZero reports whether p is the zero AddrPort.
+func (p AddrPort) isZero() bool { return p == AddrPort{} }
+
+// IsValid reports whether p.IP() is valid.
+// All ports are valid, including zero.
+func (p AddrPort) IsValid() bool { return p.ip.IsValid() }
+
+func (p AddrPort) String() string {
+ switch p.ip.z {
+ case z0:
+ return "invalid AddrPort"
+ case z4:
+ a := p.ip.As4()
+ buf := make([]byte, 0, 21)
+ for i := range a {
+ buf = strconv.AppendUint(buf, uint64(a[i]), 10)
+ buf = append(buf, "...:"[i])
+ }
+ buf = strconv.AppendUint(buf, uint64(p.port), 10)
+ return string(buf)
+ default:
+ // TODO: this could be more efficient allocation-wise:
+ return joinHostPort(p.ip.String(), itoa.Itoa(int(p.port)))
+ }
+}
+
+func joinHostPort(host, port string) string {
+ // We assume that host is a literal IPv6 address if host has
+ // colons.
+ if bytealg.IndexByteString(host, ':') >= 0 {
+ return "[" + host + "]:" + port
+ }
+ return host + ":" + port
+}
+
+// AppendTo appends a text encoding of p,
+// as generated by MarshalText,
+// to b and returns the extended buffer.
+func (p AddrPort) AppendTo(b []byte) []byte {
+ switch p.ip.z {
+ case z0:
+ return b
+ case z4:
+ b = p.ip.appendTo4(b)
+ default:
+ if p.ip.Is4In6() {
+ b = append(b, "[::ffff:"...)
+ b = p.ip.Unmap().appendTo4(b)
+ if z := p.ip.Zone(); z != "" {
+ b = append(b, '%')
+ b = append(b, z...)
+ }
+ } else {
+ b = append(b, '[')
+ b = p.ip.appendTo6(b)
+ }
+ b = append(b, ']')
+ }
+ b = append(b, ':')
+ b = strconv.AppendInt(b, int64(p.port), 10)
+ return b
+}
+
+// MarshalText implements the encoding.TextMarshaler interface. The
+// encoding is the same as returned by String, with one exception: if
+// p.Addr() is the zero Addr, the encoding is the empty string.
+func (p AddrPort) MarshalText() ([]byte, error) {
+ var max int
+ switch p.ip.z {
+ case z0:
+ case z4:
+ max = len("255.255.255.255:65535")
+ default:
+ max = len("[ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff%enp5s0]:65535")
+ }
+ b := make([]byte, 0, max)
+ b = p.AppendTo(b)
+ return b, nil
+}
+
+// UnmarshalText implements the encoding.TextUnmarshaler
+// interface. The AddrPort is expected in a form
+// generated by MarshalText or accepted by ParseAddrPort.
+func (p *AddrPort) UnmarshalText(text []byte) error {
+ if len(text) == 0 {
+ *p = AddrPort{}
+ return nil
+ }
+ var err error
+ *p, err = ParseAddrPort(string(text))
+ return err
+}
+
+// MarshalBinary implements the encoding.BinaryMarshaler interface.
+// It returns Addr.MarshalBinary with an additional two bytes appended
+// containing the port in little-endian.
+func (p AddrPort) MarshalBinary() ([]byte, error) {
+ b := p.Addr().marshalBinaryWithTrailingBytes(2)
+ lePutUint16(b[len(b)-2:], p.Port())
+ return b, nil
+}
+
+// UnmarshalBinary implements the encoding.BinaryUnmarshaler interface.
+// It expects data in the form generated by MarshalBinary.
+func (p *AddrPort) UnmarshalBinary(b []byte) error {
+ if len(b) < 2 {
+ return errors.New("unexpected slice size")
+ }
+ var addr Addr
+ err := addr.UnmarshalBinary(b[:len(b)-2])
+ if err != nil {
+ return err
+ }
+ *p = AddrPortFrom(addr, leUint16(b[len(b)-2:]))
+ return nil
+}
+
+// Prefix is an IP address prefix (CIDR) representing an IP network.
+//
+// The first Bits() of Addr() are specified. The remaining bits match any address.
+// The range of Bits() is [0,32] for IPv4 or [0,128] for IPv6.
+type Prefix struct {
+ ip Addr
+
+ // bits is logically a uint8 (storing [0,128]) but also
+ // encodes an "invalid" bit, currently represented by the
+ // invalidPrefixBits sentinel value. It could be packed into
+ // the uint8 more with more complicated expressions in the
+ // accessors, but the extra byte (in padding anyway) doesn't
+ // hurt and simplifies code below.
+ bits int16
+}
+
+// invalidPrefixBits is the Prefix.bits value used when PrefixFrom is
+// outside the range of a uint8. It's returned as the int -1 in the
+// public API.
+const invalidPrefixBits = -1
+
+// PrefixFrom returns a Prefix with the provided IP address and bit
+// prefix length.
+//
+// It does not allocate. Unlike Addr.Prefix, PrefixFrom does not mask
+// off the host bits of ip.
+//
+// If bits is less than zero or greater than ip.BitLen, Prefix.Bits
+// will return an invalid value -1.
+func PrefixFrom(ip Addr, bits int) Prefix {
+ if bits < 0 || bits > ip.BitLen() {
+ bits = invalidPrefixBits
+ }
+ b16 := int16(bits)
+ return Prefix{
+ ip: ip.withoutZone(),
+ bits: b16,
+ }
+}
+
+// Addr returns p's IP address.
+func (p Prefix) Addr() Addr { return p.ip }
+
+// Bits returns p's prefix length.
+//
+// It reports -1 if invalid.
+func (p Prefix) Bits() int { return int(p.bits) }
+
+// IsValid reports whether p.Bits() has a valid range for p.IP().
+// If p.Addr() is the zero Addr, IsValid returns false.
+// Note that if p is the zero Prefix, then p.IsValid() == false.
+func (p Prefix) IsValid() bool { return !p.ip.isZero() && p.bits >= 0 && int(p.bits) <= p.ip.BitLen() }
+
+func (p Prefix) isZero() bool { return p == Prefix{} }
+
+// IsSingleIP reports whether p contains exactly one IP.
+func (p Prefix) IsSingleIP() bool { return p.bits != 0 && int(p.bits) == p.ip.BitLen() }
+
+// ParsePrefix parses s as an IP address prefix.
+// The string can be in the form "192.168.1.0/24" or "2001::db8::/32",
+// the CIDR notation defined in RFC 4632 and RFC 4291.
+//
+// Note that masked address bits are not zeroed. Use Masked for that.
+func ParsePrefix(s string) (Prefix, error) {
+ i := stringsLastIndexByte(s, '/')
+ if i < 0 {
+ return Prefix{}, errors.New("netip.ParsePrefix(" + strconv.Quote(s) + "): no '/'")
+ }
+ ip, err := ParseAddr(s[:i])
+ if err != nil {
+ return Prefix{}, errors.New("netip.ParsePrefix(" + strconv.Quote(s) + "): " + err.Error())
+ }
+ bitsStr := s[i+1:]
+ bits, err := strconv.Atoi(bitsStr)
+ if err != nil {
+ return Prefix{}, errors.New("netip.ParsePrefix(" + strconv.Quote(s) + ": bad bits after slash: " + strconv.Quote(bitsStr))
+ }
+ maxBits := 32
+ if ip.Is6() {
+ maxBits = 128
+ }
+ if bits < 0 || bits > maxBits {
+ return Prefix{}, errors.New("netip.ParsePrefix(" + strconv.Quote(s) + ": prefix length out of range")
+ }
+ return PrefixFrom(ip, bits), nil
+}
+
+// MustParsePrefix calls ParsePrefix(s) and panics on error.
+// It is intended for use in tests with hard-coded strings.
+func MustParsePrefix(s string) Prefix {
+ ip, err := ParsePrefix(s)
+ if err != nil {
+ panic(err)
+ }
+ return ip
+}
+
+// Masked returns p in its canonical form, with all but the high
+// p.Bits() bits of p.Addr() masked off.
+//
+// If p is zero or otherwise invalid, Masked returns the zero Prefix.
+func (p Prefix) Masked() Prefix {
+ if m, err := p.ip.Prefix(int(p.bits)); err == nil {
+ return m
+ }
+ return Prefix{}
+}
+
+// Contains reports whether the network p includes ip.
+//
+// An IPv4 address will not match an IPv6 prefix.
+// A v6-mapped IPv6 address will not match an IPv4 prefix.
+// A zero-value IP will not match any prefix.
+// If ip has an IPv6 zone, Contains returns false,
+// because Prefixes strip zones.
+func (p Prefix) Contains(ip Addr) bool {
+ if !p.IsValid() || ip.hasZone() {
+ return false
+ }
+ if f1, f2 := p.ip.BitLen(), ip.BitLen(); f1 == 0 || f2 == 0 || f1 != f2 {
+ return false
+ }
+ if ip.Is4() {
+ // xor the IP addresses together; mismatched bits are now ones.
+ // Shift away the number of bits we don't care about.
+ // Shifts in Go are more efficient if the compiler can prove
+ // that the shift amount is smaller than the width of the shifted type (64 here).
+ // We know that p.bits is in the range 0..32 because p is Valid;
+ // the compiler doesn't know that, so mask with 63 to help it.
+ // Now truncate to 32 bits, because this is IPv4.
+ // If all the bits we care about are equal, the result will be zero.
+ return uint32((ip.addr.lo^p.ip.addr.lo)>>((32-p.bits)&63)) == 0
+ } else {
+ // xor the IP addresses together.
+ // Mask away the bits we don't care about.
+ // If all the bits we care about are equal, the result will be zero.
+ return ip.addr.xor(p.ip.addr).and(mask6(int(p.bits))).isZero()
+ }
+}
+
+// Overlaps reports whether p and o contain any IP addresses in common.
+//
+// If p and o are of different address families or either have a zero
+// IP, it reports false. Like the Contains method, a prefix with a
+// v6-mapped IPv4 IP is still treated as an IPv6 mask.
+func (p Prefix) Overlaps(o Prefix) bool {
+ if !p.IsValid() || !o.IsValid() {
+ return false
+ }
+ if p == o {
+ return true
+ }
+ if p.ip.Is4() != o.ip.Is4() {
+ return false
+ }
+ var minBits int16
+ if p.bits < o.bits {
+ minBits = p.bits
+ } else {
+ minBits = o.bits
+ }
+ if minBits == 0 {
+ return true
+ }
+ // One of these Prefix calls might look redundant, but we don't require
+ // that p and o values are normalized (via Prefix.Masked) first,
+ // so the Prefix call on the one that's already minBits serves to zero
+ // out any remaining bits in IP.
+ var err error
+ if p, err = p.ip.Prefix(int(minBits)); err != nil {
+ return false
+ }
+ if o, err = o.ip.Prefix(int(minBits)); err != nil {
+ return false
+ }
+ return p.ip == o.ip
+}
+
+// AppendTo appends a text encoding of p,
+// as generated by MarshalText,
+// to b and returns the extended buffer.
+func (p Prefix) AppendTo(b []byte) []byte {
+ if p.isZero() {
+ return b
+ }
+ if !p.IsValid() {
+ return append(b, "invalid Prefix"...)
+ }
+
+ // p.ip is non-nil, because p is valid.
+ if p.ip.z == z4 {
+ b = p.ip.appendTo4(b)
+ } else {
+ if p.ip.Is4In6() {
+ b = append(b, "::ffff:"...)
+ b = p.ip.Unmap().appendTo4(b)
+ } else {
+ b = p.ip.appendTo6(b)
+ }
+ }
+
+ b = append(b, '/')
+ b = appendDecimal(b, uint8(p.bits))
+ return b
+}
+
+// MarshalText implements the encoding.TextMarshaler interface,
+// The encoding is the same as returned by String, with one exception:
+// If p is the zero value, the encoding is the empty string.
+func (p Prefix) MarshalText() ([]byte, error) {
+ var max int
+ switch p.ip.z {
+ case z0:
+ case z4:
+ max = len("255.255.255.255/32")
+ default:
+ max = len("ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff%enp5s0/128")
+ }
+ b := make([]byte, 0, max)
+ b = p.AppendTo(b)
+ return b, nil
+}
+
+// UnmarshalText implements the encoding.TextUnmarshaler interface.
+// The IP address is expected in a form accepted by ParsePrefix
+// or generated by MarshalText.
+func (p *Prefix) UnmarshalText(text []byte) error {
+ if len(text) == 0 {
+ *p = Prefix{}
+ return nil
+ }
+ var err error
+ *p, err = ParsePrefix(string(text))
+ return err
+}
+
+// MarshalBinary implements the encoding.BinaryMarshaler interface.
+// It returns Addr.MarshalBinary with an additional byte appended
+// containing the prefix bits.
+func (p Prefix) MarshalBinary() ([]byte, error) {
+ b := p.Addr().withoutZone().marshalBinaryWithTrailingBytes(1)
+ b[len(b)-1] = uint8(p.Bits())
+ return b, nil
+}
+
+// UnmarshalBinary implements the encoding.BinaryUnmarshaler interface.
+// It expects data in the form generated by MarshalBinary.
+func (p *Prefix) UnmarshalBinary(b []byte) error {
+ if len(b) < 1 {
+ return errors.New("unexpected slice size")
+ }
+ var addr Addr
+ err := addr.UnmarshalBinary(b[:len(b)-1])
+ if err != nil {
+ return err
+ }
+ *p = PrefixFrom(addr, int(b[len(b)-1]))
+ return nil
+}
+
+// String returns the CIDR notation of p: "<ip>/<bits>".
+func (p Prefix) String() string {
+ if !p.IsValid() {
+ return "invalid Prefix"
+ }
+ return p.ip.String() + "/" + itoa.Itoa(int(p.bits))
+}
diff --git a/libgo/go/net/netip/netip_pkg_test.go b/libgo/go/net/netip/netip_pkg_test.go
new file mode 100644
index 0000000..f5cd9ee
--- /dev/null
+++ b/libgo/go/net/netip/netip_pkg_test.go
@@ -0,0 +1,359 @@
+// Copyright 2020 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package netip
+
+import (
+ "bytes"
+ "encoding"
+ "encoding/json"
+ "strings"
+ "testing"
+)
+
+var (
+ mustPrefix = MustParsePrefix
+ mustIP = MustParseAddr
+)
+
+func TestPrefixValid(t *testing.T) {
+ v4 := MustParseAddr("1.2.3.4")
+ v6 := MustParseAddr("::1")
+ tests := []struct {
+ ipp Prefix
+ want bool
+ }{
+ {Prefix{v4, -2}, false},
+ {Prefix{v4, -1}, false},
+ {Prefix{v4, 0}, true},
+ {Prefix{v4, 32}, true},
+ {Prefix{v4, 33}, false},
+
+ {Prefix{v6, -2}, false},
+ {Prefix{v6, -1}, false},
+ {Prefix{v6, 0}, true},
+ {Prefix{v6, 32}, true},
+ {Prefix{v6, 128}, true},
+ {Prefix{v6, 129}, false},
+
+ {Prefix{Addr{}, -2}, false},
+ {Prefix{Addr{}, -1}, false},
+ {Prefix{Addr{}, 0}, false},
+ {Prefix{Addr{}, 32}, false},
+ {Prefix{Addr{}, 128}, false},
+ }
+ for _, tt := range tests {
+ got := tt.ipp.IsValid()
+ if got != tt.want {
+ t.Errorf("(%v).IsValid() = %v want %v", tt.ipp, got, tt.want)
+ }
+ }
+}
+
+var nextPrevTests = []struct {
+ ip Addr
+ next Addr
+ prev Addr
+}{
+ {mustIP("10.0.0.1"), mustIP("10.0.0.2"), mustIP("10.0.0.0")},
+ {mustIP("10.0.0.255"), mustIP("10.0.1.0"), mustIP("10.0.0.254")},
+ {mustIP("127.0.0.1"), mustIP("127.0.0.2"), mustIP("127.0.0.0")},
+ {mustIP("254.255.255.255"), mustIP("255.0.0.0"), mustIP("254.255.255.254")},
+ {mustIP("255.255.255.255"), Addr{}, mustIP("255.255.255.254")},
+ {mustIP("0.0.0.0"), mustIP("0.0.0.1"), Addr{}},
+ {mustIP("::"), mustIP("::1"), Addr{}},
+ {mustIP("::%x"), mustIP("::1%x"), Addr{}},
+ {mustIP("::1"), mustIP("::2"), mustIP("::")},
+ {mustIP("ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff"), Addr{}, mustIP("ffff:ffff:ffff:ffff:ffff:ffff:ffff:fffe")},
+}
+
+func TestIPNextPrev(t *testing.T) {
+ doNextPrev(t)
+
+ for _, ip := range []Addr{
+ mustIP("0.0.0.0"),
+ mustIP("::"),
+ } {
+ got := ip.Prev()
+ if !got.isZero() {
+ t.Errorf("IP(%v).Prev = %v; want zero", ip, got)
+ }
+ }
+
+ var allFF [16]byte
+ for i := range allFF {
+ allFF[i] = 0xff
+ }
+
+ for _, ip := range []Addr{
+ mustIP("255.255.255.255"),
+ AddrFrom16(allFF),
+ } {
+ got := ip.Next()
+ if !got.isZero() {
+ t.Errorf("IP(%v).Next = %v; want zero", ip, got)
+ }
+ }
+}
+
+func BenchmarkIPNextPrev(b *testing.B) {
+ for i := 0; i < b.N; i++ {
+ doNextPrev(b)
+ }
+}
+
+func doNextPrev(t testing.TB) {
+ for _, tt := range nextPrevTests {
+ gnext, gprev := tt.ip.Next(), tt.ip.Prev()
+ if gnext != tt.next {
+ t.Errorf("IP(%v).Next = %v; want %v", tt.ip, gnext, tt.next)
+ }
+ if gprev != tt.prev {
+ t.Errorf("IP(%v).Prev = %v; want %v", tt.ip, gprev, tt.prev)
+ }
+ if !tt.ip.Next().isZero() && tt.ip.Next().Prev() != tt.ip {
+ t.Errorf("IP(%v).Next.Prev = %v; want %v", tt.ip, tt.ip.Next().Prev(), tt.ip)
+ }
+ if !tt.ip.Prev().isZero() && tt.ip.Prev().Next() != tt.ip {
+ t.Errorf("IP(%v).Prev.Next = %v; want %v", tt.ip, tt.ip.Prev().Next(), tt.ip)
+ }
+ }
+}
+
+func TestIPBitLen(t *testing.T) {
+ tests := []struct {
+ ip Addr
+ want int
+ }{
+ {Addr{}, 0},
+ {mustIP("0.0.0.0"), 32},
+ {mustIP("10.0.0.1"), 32},
+ {mustIP("::"), 128},
+ {mustIP("fed0::1"), 128},
+ {mustIP("::ffff:10.0.0.1"), 128},
+ }
+ for _, tt := range tests {
+ got := tt.ip.BitLen()
+ if got != tt.want {
+ t.Errorf("BitLen(%v) = %d; want %d", tt.ip, got, tt.want)
+ }
+ }
+}
+
+func TestPrefixContains(t *testing.T) {
+ tests := []struct {
+ ipp Prefix
+ ip Addr
+ want bool
+ }{
+ {mustPrefix("9.8.7.6/0"), mustIP("9.8.7.6"), true},
+ {mustPrefix("9.8.7.6/16"), mustIP("9.8.7.6"), true},
+ {mustPrefix("9.8.7.6/16"), mustIP("9.8.6.4"), true},
+ {mustPrefix("9.8.7.6/16"), mustIP("9.9.7.6"), false},
+ {mustPrefix("9.8.7.6/32"), mustIP("9.8.7.6"), true},
+ {mustPrefix("9.8.7.6/32"), mustIP("9.8.7.7"), false},
+ {mustPrefix("9.8.7.6/32"), mustIP("9.8.7.7"), false},
+ {mustPrefix("::1/0"), mustIP("::1"), true},
+ {mustPrefix("::1/0"), mustIP("::2"), true},
+ {mustPrefix("::1/127"), mustIP("::1"), true},
+ {mustPrefix("::1/127"), mustIP("::2"), false},
+ {mustPrefix("::1/128"), mustIP("::1"), true},
+ {mustPrefix("::1/127"), mustIP("::2"), false},
+ // zones support
+ {mustPrefix("::1%a/128"), mustIP("::1"), true}, // prefix zones are stripped...
+ {mustPrefix("::1%a/128"), mustIP("::1%a"), false}, // but ip zones are not
+ // invalid IP
+ {mustPrefix("::1/0"), Addr{}, false},
+ {mustPrefix("1.2.3.4/0"), Addr{}, false},
+ // invalid Prefix
+ {Prefix{mustIP("::1"), 129}, mustIP("::1"), false},
+ {Prefix{mustIP("1.2.3.4"), 33}, mustIP("1.2.3.4"), false},
+ {Prefix{Addr{}, 0}, mustIP("1.2.3.4"), false},
+ {Prefix{Addr{}, 32}, mustIP("1.2.3.4"), false},
+ {Prefix{Addr{}, 128}, mustIP("::1"), false},
+ // wrong IP family
+ {mustPrefix("::1/0"), mustIP("1.2.3.4"), false},
+ {mustPrefix("1.2.3.4/0"), mustIP("::1"), false},
+ }
+ for _, tt := range tests {
+ got := tt.ipp.Contains(tt.ip)
+ if got != tt.want {
+ t.Errorf("(%v).Contains(%v) = %v want %v", tt.ipp, tt.ip, got, tt.want)
+ }
+ }
+}
+
+func TestParseIPError(t *testing.T) {
+ tests := []struct {
+ ip string
+ errstr string
+ }{
+ {
+ ip: "localhost",
+ },
+ {
+ ip: "500.0.0.1",
+ errstr: "field has value >255",
+ },
+ {
+ ip: "::gggg%eth0",
+ errstr: "must have at least one digit",
+ },
+ {
+ ip: "fe80::1cc0:3e8c:119f:c2e1%",
+ errstr: "zone must be a non-empty string",
+ },
+ {
+ ip: "%eth0",
+ errstr: "missing IPv6 address",
+ },
+ }
+ for _, test := range tests {
+ t.Run(test.ip, func(t *testing.T) {
+ _, err := ParseAddr(test.ip)
+ if err == nil {
+ t.Fatal("no error")
+ }
+ if _, ok := err.(parseAddrError); !ok {
+ t.Errorf("error type is %T, want parseIPError", err)
+ }
+ if test.errstr == "" {
+ test.errstr = "unable to parse IP"
+ }
+ if got := err.Error(); !strings.Contains(got, test.errstr) {
+ t.Errorf("error is missing substring %q: %s", test.errstr, got)
+ }
+ })
+ }
+}
+
+func TestParseAddrPort(t *testing.T) {
+ tests := []struct {
+ in string
+ want AddrPort
+ wantErr bool
+ }{
+ {in: "1.2.3.4:1234", want: AddrPort{mustIP("1.2.3.4"), 1234}},
+ {in: "1.1.1.1:123456", wantErr: true},
+ {in: "1.1.1.1:-123", wantErr: true},
+ {in: "[::1]:1234", want: AddrPort{mustIP("::1"), 1234}},
+ {in: "[1.2.3.4]:1234", wantErr: true},
+ {in: "fe80::1:1234", wantErr: true},
+ {in: ":0", wantErr: true}, // if we need to parse this form, there should be a separate function that explicitly allows it
+ }
+ for _, test := range tests {
+ t.Run(test.in, func(t *testing.T) {
+ got, err := ParseAddrPort(test.in)
+ if err != nil {
+ if test.wantErr {
+ return
+ }
+ t.Fatal(err)
+ }
+ if got != test.want {
+ t.Errorf("got %v; want %v", got, test.want)
+ }
+ if got.String() != test.in {
+ t.Errorf("String = %q; want %q", got.String(), test.in)
+ }
+ })
+
+ t.Run(test.in+"/AppendTo", func(t *testing.T) {
+ got, err := ParseAddrPort(test.in)
+ if err == nil {
+ testAppendToMarshal(t, got)
+ }
+ })
+
+ // TextMarshal and TextUnmarshal mostly behave like
+ // ParseAddrPort and String. Divergent behavior are handled in
+ // TestAddrPortMarshalUnmarshal.
+ t.Run(test.in+"/Marshal", func(t *testing.T) {
+ var got AddrPort
+ jsin := `"` + test.in + `"`
+ err := json.Unmarshal([]byte(jsin), &got)
+ if err != nil {
+ if test.wantErr {
+ return
+ }
+ t.Fatal(err)
+ }
+ if got != test.want {
+ t.Errorf("got %v; want %v", got, test.want)
+ }
+ gotb, err := json.Marshal(got)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if string(gotb) != jsin {
+ t.Errorf("Marshal = %q; want %q", string(gotb), jsin)
+ }
+ })
+ }
+}
+
+func TestAddrPortMarshalUnmarshal(t *testing.T) {
+ tests := []struct {
+ in string
+ want AddrPort
+ }{
+ {"", AddrPort{}},
+ }
+
+ for _, test := range tests {
+ t.Run(test.in, func(t *testing.T) {
+ orig := `"` + test.in + `"`
+
+ var ipp AddrPort
+ if err := json.Unmarshal([]byte(orig), &ipp); err != nil {
+ t.Fatalf("failed to unmarshal: %v", err)
+ }
+
+ ippb, err := json.Marshal(ipp)
+ if err != nil {
+ t.Fatalf("failed to marshal: %v", err)
+ }
+
+ back := string(ippb)
+ if orig != back {
+ t.Errorf("Marshal = %q; want %q", back, orig)
+ }
+
+ testAppendToMarshal(t, ipp)
+ })
+ }
+}
+
+type appendMarshaler interface {
+ encoding.TextMarshaler
+ AppendTo([]byte) []byte
+}
+
+// testAppendToMarshal tests that x's AppendTo and MarshalText methods yield the same results.
+// x's MarshalText method must not return an error.
+func testAppendToMarshal(t *testing.T, x appendMarshaler) {
+ t.Helper()
+ m, err := x.MarshalText()
+ if err != nil {
+ t.Fatalf("(%v).MarshalText: %v", x, err)
+ }
+ a := make([]byte, 0, len(m))
+ a = x.AppendTo(a)
+ if !bytes.Equal(m, a) {
+ t.Errorf("(%v).MarshalText = %q, (%v).AppendTo = %q", x, m, x, a)
+ }
+}
+
+func TestIPv6Accessor(t *testing.T) {
+ var a [16]byte
+ for i := range a {
+ a[i] = uint8(i) + 1
+ }
+ ip := AddrFrom16(a)
+ for i := range a {
+ if got, want := ip.v6(uint8(i)), uint8(i)+1; got != want {
+ t.Errorf("v6(%v) = %v; want %v", i, got, want)
+ }
+ }
+}
diff --git a/libgo/go/net/netip/netip_test.go b/libgo/go/net/netip/netip_test.go
new file mode 100644
index 0000000..d988864
--- /dev/null
+++ b/libgo/go/net/netip/netip_test.go
@@ -0,0 +1,1974 @@
+// Copyright 2020 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package netip_test
+
+import (
+ "bytes"
+ "encoding/json"
+ "flag"
+ "fmt"
+ "internal/intern"
+ "net"
+ . "net/netip"
+ "reflect"
+ "sort"
+ "strings"
+ "testing"
+)
+
+var long = flag.Bool("long", false, "run long tests")
+
+type uint128 = Uint128
+
+var (
+ mustPrefix = MustParsePrefix
+ mustIP = MustParseAddr
+ mustIPPort = MustParseAddrPort
+)
+
+func TestParseAddr(t *testing.T) {
+ var validIPs = []struct {
+ in string
+ ip Addr // output of ParseAddr()
+ str string // output of String(). If "", use in.
+ wantErr string
+ }{
+ // Basic zero IPv4 address.
+ {
+ in: "0.0.0.0",
+ ip: MkAddr(Mk128(0, 0xffff00000000), Z4),
+ },
+ // Basic non-zero IPv4 address.
+ {
+ in: "192.168.140.255",
+ ip: MkAddr(Mk128(0, 0xffffc0a88cff), Z4),
+ },
+ // IPv4 address in windows-style "print all the digits" form.
+ {
+ in: "010.000.015.001",
+ wantErr: `ParseAddr("010.000.015.001"): IPv4 field has octet with leading zero`,
+ },
+ // IPv4 address with a silly amount of leading zeros.
+ {
+ in: "000001.00000002.00000003.000000004",
+ wantErr: `ParseAddr("000001.00000002.00000003.000000004"): IPv4 field has octet with leading zero`,
+ },
+ // 4-in-6 with octet with leading zero
+ {
+ in: "::ffff:1.2.03.4",
+ wantErr: `ParseAddr("::ffff:1.2.03.4"): ParseAddr("1.2.03.4"): IPv4 field has octet with leading zero (at "1.2.03.4")`,
+ },
+ // Basic zero IPv6 address.
+ {
+ in: "::",
+ ip: MkAddr(Mk128(0, 0), Z6noz),
+ },
+ // Localhost IPv6.
+ {
+ in: "::1",
+ ip: MkAddr(Mk128(0, 1), Z6noz),
+ },
+ // Fully expanded IPv6 address.
+ {
+ in: "fd7a:115c:a1e0:ab12:4843:cd96:626b:430b",
+ ip: MkAddr(Mk128(0xfd7a115ca1e0ab12, 0x4843cd96626b430b), Z6noz),
+ },
+ // IPv6 with elided fields in the middle.
+ {
+ in: "fd7a:115c::626b:430b",
+ ip: MkAddr(Mk128(0xfd7a115c00000000, 0x00000000626b430b), Z6noz),
+ },
+ // IPv6 with elided fields at the end.
+ {
+ in: "fd7a:115c:a1e0:ab12:4843:cd96::",
+ ip: MkAddr(Mk128(0xfd7a115ca1e0ab12, 0x4843cd9600000000), Z6noz),
+ },
+ // IPv6 with single elided field at the end.
+ {
+ in: "fd7a:115c:a1e0:ab12:4843:cd96:626b::",
+ ip: MkAddr(Mk128(0xfd7a115ca1e0ab12, 0x4843cd96626b0000), Z6noz),
+ str: "fd7a:115c:a1e0:ab12:4843:cd96:626b:0",
+ },
+ // IPv6 with single elided field in the middle.
+ {
+ in: "fd7a:115c:a1e0::4843:cd96:626b:430b",
+ ip: MkAddr(Mk128(0xfd7a115ca1e00000, 0x4843cd96626b430b), Z6noz),
+ str: "fd7a:115c:a1e0:0:4843:cd96:626b:430b",
+ },
+ // IPv6 with the trailing 32 bits written as IPv4 dotted decimal. (4in6)
+ {
+ in: "::ffff:192.168.140.255",
+ ip: MkAddr(Mk128(0, 0x0000ffffc0a88cff), Z6noz),
+ str: "::ffff:192.168.140.255",
+ },
+ // IPv6 with a zone specifier.
+ {
+ in: "fd7a:115c:a1e0:ab12:4843:cd96:626b:430b%eth0",
+ ip: MkAddr(Mk128(0xfd7a115ca1e0ab12, 0x4843cd96626b430b), intern.Get("eth0")),
+ },
+ // IPv6 with dotted decimal and zone specifier.
+ {
+ in: "1:2::ffff:192.168.140.255%eth1",
+ ip: MkAddr(Mk128(0x0001000200000000, 0x0000ffffc0a88cff), intern.Get("eth1")),
+ str: "1:2::ffff:c0a8:8cff%eth1",
+ },
+ // 4-in-6 with zone
+ {
+ in: "::ffff:192.168.140.255%eth1",
+ ip: MkAddr(Mk128(0, 0x0000ffffc0a88cff), intern.Get("eth1")),
+ str: "::ffff:192.168.140.255%eth1",
+ },
+ // IPv6 with capital letters.
+ {
+ in: "FD9E:1A04:F01D::1",
+ ip: MkAddr(Mk128(0xfd9e1a04f01d0000, 0x1), Z6noz),
+ str: "fd9e:1a04:f01d::1",
+ },
+ }
+
+ for _, test := range validIPs {
+ t.Run(test.in, func(t *testing.T) {
+ got, err := ParseAddr(test.in)
+ if err != nil {
+ if err.Error() == test.wantErr {
+ return
+ }
+ t.Fatal(err)
+ }
+ if test.wantErr != "" {
+ t.Fatalf("wanted error %q; got none", test.wantErr)
+ }
+ if got != test.ip {
+ t.Errorf("got %#v, want %#v", got, test.ip)
+ }
+
+ // Check that ParseAddr is a pure function.
+ got2, err := ParseAddr(test.in)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if got != got2 {
+ t.Errorf("ParseAddr(%q) got 2 different results: %#v, %#v", test.in, got, got2)
+ }
+
+ // Check that ParseAddr(ip.String()) is the identity function.
+ s := got.String()
+ got3, err := ParseAddr(s)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if got != got3 {
+ t.Errorf("ParseAddr(%q) != ParseAddr(ParseIP(%q).String()). Got %#v, want %#v", test.in, test.in, got3, got)
+ }
+
+ // Check that the slow-but-readable parser produces the same result.
+ slow, err := parseIPSlow(test.in)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if got != slow {
+ t.Errorf("ParseAddr(%q) = %#v, parseIPSlow(%q) = %#v", test.in, got, test.in, slow)
+ }
+
+ // Check that the parsed IP formats as expected.
+ s = got.String()
+ wants := test.str
+ if wants == "" {
+ wants = test.in
+ }
+ if s != wants {
+ t.Errorf("ParseAddr(%q).String() got %q, want %q", test.in, s, wants)
+ }
+
+ // Check that AppendTo matches MarshalText.
+ TestAppendToMarshal(t, got)
+
+ // Check that MarshalText/UnmarshalText work similarly to
+ // ParseAddr/String (see TestIPMarshalUnmarshal for
+ // marshal-specific behavior that's not common with
+ // ParseAddr/String).
+ js := `"` + test.in + `"`
+ var jsgot Addr
+ if err := json.Unmarshal([]byte(js), &jsgot); err != nil {
+ t.Fatal(err)
+ }
+ if jsgot != got {
+ t.Errorf("json.Unmarshal(%q) = %#v, want %#v", test.in, jsgot, got)
+ }
+ jsb, err := json.Marshal(jsgot)
+ if err != nil {
+ t.Fatal(err)
+ }
+ jswant := `"` + wants + `"`
+ jsback := string(jsb)
+ if jsback != jswant {
+ t.Errorf("Marshal(Unmarshal(%q)) = %s, want %s", test.in, jsback, jswant)
+ }
+ })
+ }
+
+ var invalidIPs = []string{
+ // Empty string
+ "",
+ // Garbage non-IP
+ "bad",
+ // Single number. Some parsers accept this as an IPv4 address in
+ // big-endian uint32 form, but we don't.
+ "1234",
+ // IPv4 with a zone specifier
+ "1.2.3.4%eth0",
+ // IPv4 field must have at least one digit
+ ".1.2.3",
+ "1.2.3.",
+ "1..2.3",
+ // IPv4 address too long
+ "1.2.3.4.5",
+ // IPv4 in dotted octal form
+ "0300.0250.0214.0377",
+ // IPv4 in dotted hex form
+ "0xc0.0xa8.0x8c.0xff",
+ // IPv4 in class B form
+ "192.168.12345",
+ // IPv4 in class B form, with a small enough number to be
+ // parseable as a regular dotted decimal field.
+ "127.0.1",
+ // IPv4 in class A form
+ "192.1234567",
+ // IPv4 in class A form, with a small enough number to be
+ // parseable as a regular dotted decimal field.
+ "127.1",
+ // IPv4 field has value >255
+ "192.168.300.1",
+ // IPv4 with too many fields
+ "192.168.0.1.5.6",
+ // IPv6 with not enough fields
+ "1:2:3:4:5:6:7",
+ // IPv6 with too many fields
+ "1:2:3:4:5:6:7:8:9",
+ // IPv6 with 8 fields and a :: expander
+ "1:2:3:4::5:6:7:8",
+ // IPv6 with a field bigger than 2b
+ "fe801::1",
+ // IPv6 with non-hex values in field
+ "fe80:tail:scal:e::",
+ // IPv6 with a zone delimiter but no zone.
+ "fe80::1%",
+ // IPv6 (without ellipsis) with too many fields for trailing embedded IPv4.
+ "ffff:ffff:ffff:ffff:ffff:ffff:ffff:192.168.140.255",
+ // IPv6 (with ellipsis) with too many fields for trailing embedded IPv4.
+ "ffff::ffff:ffff:ffff:ffff:ffff:ffff:192.168.140.255",
+ // IPv6 with invalid embedded IPv4.
+ "::ffff:192.168.140.bad",
+ // IPv6 with multiple ellipsis ::.
+ "fe80::1::1",
+ // IPv6 with invalid non hex/colon character.
+ "fe80:1?:1",
+ // IPv6 with truncated bytes after single colon.
+ "fe80:",
+ }
+
+ for _, s := range invalidIPs {
+ t.Run(s, func(t *testing.T) {
+ got, err := ParseAddr(s)
+ if err == nil {
+ t.Errorf("ParseAddr(%q) = %#v, want error", s, got)
+ }
+
+ slow, err := parseIPSlow(s)
+ if err == nil {
+ t.Errorf("parseIPSlow(%q) = %#v, want error", s, slow)
+ }
+
+ std := net.ParseIP(s)
+ if std != nil {
+ t.Errorf("net.ParseIP(%q) = %#v, want error", s, std)
+ }
+
+ if s == "" {
+ // Don't test unmarshaling of "" here, do it in
+ // IPMarshalUnmarshal.
+ return
+ }
+ var jsgot Addr
+ js := []byte(`"` + s + `"`)
+ if err := json.Unmarshal(js, &jsgot); err == nil {
+ t.Errorf("json.Unmarshal(%q) = %#v, want error", s, jsgot)
+ }
+ })
+ }
+}
+
+func TestIPv4Constructors(t *testing.T) {
+ if AddrFrom4([4]byte{1, 2, 3, 4}) != MustParseAddr("1.2.3.4") {
+ t.Errorf("don't match")
+ }
+}
+
+func TestAddrMarshalUnmarshalBinary(t *testing.T) {
+ tests := []struct {
+ ip string
+ wantSize int
+ }{
+ {"", 0}, // zero IP
+ {"1.2.3.4", 4},
+ {"fd7a:115c:a1e0:ab12:4843:cd96:626b:430b", 16},
+ {"::ffff:c000:0280", 16},
+ {"::ffff:c000:0280%eth0", 20},
+ }
+ for _, tc := range tests {
+ var ip Addr
+ if len(tc.ip) > 0 {
+ ip = mustIP(tc.ip)
+ }
+ b, err := ip.MarshalBinary()
+ if err != nil {
+ t.Fatal(err)
+ }
+ if len(b) != tc.wantSize {
+ t.Fatalf("%q encoded to size %d; want %d", tc.ip, len(b), tc.wantSize)
+ }
+ var ip2 Addr
+ if err := ip2.UnmarshalBinary(b); err != nil {
+ t.Fatal(err)
+ }
+ if ip != ip2 {
+ t.Fatalf("got %v; want %v", ip2, ip)
+ }
+ }
+
+ // Cannot unmarshal from unexpected IP length.
+ for _, n := range []int{3, 5} {
+ var ip2 Addr
+ if err := ip2.UnmarshalBinary(bytes.Repeat([]byte{1}, n)); err == nil {
+ t.Fatalf("unmarshaled from unexpected IP length %d", n)
+ }
+ }
+}
+
+func TestAddrPortMarshalTextString(t *testing.T) {
+ tests := []struct {
+ in AddrPort
+ want string
+ }{
+ {mustIPPort("1.2.3.4:80"), "1.2.3.4:80"},
+ {mustIPPort("[1::CAFE]:80"), "[1::cafe]:80"},
+ {mustIPPort("[1::CAFE%en0]:80"), "[1::cafe%en0]:80"},
+ {mustIPPort("[::FFFF:192.168.140.255]:80"), "[::ffff:192.168.140.255]:80"},
+ {mustIPPort("[::FFFF:192.168.140.255%en0]:80"), "[::ffff:192.168.140.255%en0]:80"},
+ }
+ for i, tt := range tests {
+ if got := tt.in.String(); got != tt.want {
+ t.Errorf("%d. for (%v, %v) String = %q; want %q", i, tt.in.Addr(), tt.in.Port(), got, tt.want)
+ }
+ mt, err := tt.in.MarshalText()
+ if err != nil {
+ t.Errorf("%d. for (%v, %v) MarshalText error: %v", i, tt.in.Addr(), tt.in.Port(), err)
+ continue
+ }
+ if string(mt) != tt.want {
+ t.Errorf("%d. for (%v, %v) MarshalText = %q; want %q", i, tt.in.Addr(), tt.in.Port(), mt, tt.want)
+ }
+ }
+}
+
+func TestAddrPortMarshalUnmarshalBinary(t *testing.T) {
+ tests := []struct {
+ ipport string
+ wantSize int
+ }{
+ {"1.2.3.4:51820", 4 + 2},
+ {"[fd7a:115c:a1e0:ab12:4843:cd96:626b:430b]:80", 16 + 2},
+ {"[::ffff:c000:0280]:65535", 16 + 2},
+ {"[::ffff:c000:0280%eth0]:1", 20 + 2},
+ }
+ for _, tc := range tests {
+ var ipport AddrPort
+ if len(tc.ipport) > 0 {
+ ipport = mustIPPort(tc.ipport)
+ }
+ b, err := ipport.MarshalBinary()
+ if err != nil {
+ t.Fatal(err)
+ }
+ if len(b) != tc.wantSize {
+ t.Fatalf("%q encoded to size %d; want %d", tc.ipport, len(b), tc.wantSize)
+ }
+ var ipport2 AddrPort
+ if err := ipport2.UnmarshalBinary(b); err != nil {
+ t.Fatal(err)
+ }
+ if ipport != ipport2 {
+ t.Fatalf("got %v; want %v", ipport2, ipport)
+ }
+ }
+
+ // Cannot unmarshal from unexpected lengths.
+ for _, n := range []int{3, 7} {
+ var ipport2 AddrPort
+ if err := ipport2.UnmarshalBinary(bytes.Repeat([]byte{1}, n)); err == nil {
+ t.Fatalf("unmarshaled from unexpected length %d", n)
+ }
+ }
+}
+
+func TestPrefixMarshalTextString(t *testing.T) {
+ tests := []struct {
+ in Prefix
+ want string
+ }{
+ {mustPrefix("1.2.3.4/24"), "1.2.3.4/24"},
+ {mustPrefix("fd7a:115c:a1e0:ab12:4843:cd96:626b:430b/118"), "fd7a:115c:a1e0:ab12:4843:cd96:626b:430b/118"},
+ {mustPrefix("::ffff:c000:0280/96"), "::ffff:192.0.2.128/96"},
+ {mustPrefix("::ffff:c000:0280%eth0/37"), "::ffff:192.0.2.128/37"}, // Zone should be stripped
+ {mustPrefix("::ffff:192.168.140.255/8"), "::ffff:192.168.140.255/8"},
+ }
+ for i, tt := range tests {
+ if got := tt.in.String(); got != tt.want {
+ t.Errorf("%d. for %v String = %q; want %q", i, tt.in, got, tt.want)
+ }
+ mt, err := tt.in.MarshalText()
+ if err != nil {
+ t.Errorf("%d. for %v MarshalText error: %v", i, tt.in, err)
+ continue
+ }
+ if string(mt) != tt.want {
+ t.Errorf("%d. for %v MarshalText = %q; want %q", i, tt.in, mt, tt.want)
+ }
+ }
+}
+
+func TestPrefixMarshalUnmarshalBinary(t *testing.T) {
+ type testCase struct {
+ prefix Prefix
+ wantSize int
+ }
+ tests := []testCase{
+ {mustPrefix("1.2.3.4/24"), 4 + 1},
+ {mustPrefix("fd7a:115c:a1e0:ab12:4843:cd96:626b:430b/118"), 16 + 1},
+ {mustPrefix("::ffff:c000:0280/96"), 16 + 1},
+ {mustPrefix("::ffff:c000:0280%eth0/37"), 16 + 1}, // Zone should be stripped
+ }
+ tests = append(tests,
+ testCase{PrefixFrom(tests[0].prefix.Addr(), 33), tests[0].wantSize},
+ testCase{PrefixFrom(tests[1].prefix.Addr(), 129), tests[1].wantSize})
+ for _, tc := range tests {
+ prefix := tc.prefix
+ b, err := prefix.MarshalBinary()
+ if err != nil {
+ t.Fatal(err)
+ }
+ if len(b) != tc.wantSize {
+ t.Fatalf("%q encoded to size %d; want %d", tc.prefix, len(b), tc.wantSize)
+ }
+ var prefix2 Prefix
+ if err := prefix2.UnmarshalBinary(b); err != nil {
+ t.Fatal(err)
+ }
+ if prefix != prefix2 {
+ t.Fatalf("got %v; want %v", prefix2, prefix)
+ }
+ }
+
+ // Cannot unmarshal from unexpected lengths.
+ for _, n := range []int{3, 6} {
+ var prefix2 Prefix
+ if err := prefix2.UnmarshalBinary(bytes.Repeat([]byte{1}, n)); err == nil {
+ t.Fatalf("unmarshaled from unexpected length %d", n)
+ }
+ }
+}
+
+func TestAddrMarshalUnmarshal(t *testing.T) {
+ // This only tests the cases where Marshal/Unmarshal diverges from
+ // the behavior of ParseAddr/String. For the rest of the test cases,
+ // see TestParseAddr above.
+ orig := `""`
+ var ip Addr
+ if err := json.Unmarshal([]byte(orig), &ip); err != nil {
+ t.Fatalf("Unmarshal(%q) got error %v", orig, err)
+ }
+ if ip != (Addr{}) {
+ t.Errorf("Unmarshal(%q) is not the zero Addr", orig)
+ }
+
+ jsb, err := json.Marshal(ip)
+ if err != nil {
+ t.Fatalf("Marshal(%v) got error %v", ip, err)
+ }
+ back := string(jsb)
+ if back != orig {
+ t.Errorf("Marshal(Unmarshal(%q)) got %q, want %q", orig, back, orig)
+ }
+}
+
+func TestAddrFrom16(t *testing.T) {
+ tests := []struct {
+ name string
+ in [16]byte
+ want Addr
+ }{
+ {
+ name: "v6-raw",
+ in: [...]byte{15: 1},
+ want: MkAddr(Mk128(0, 1), Z6noz),
+ },
+ {
+ name: "v4-raw",
+ in: [...]byte{10: 0xff, 11: 0xff, 12: 1, 13: 2, 14: 3, 15: 4},
+ want: MkAddr(Mk128(0, 0xffff01020304), Z6noz),
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got := AddrFrom16(tt.in)
+ if got != tt.want {
+ t.Errorf("got %#v; want %#v", got, tt.want)
+ }
+ })
+ }
+}
+
+func TestIPProperties(t *testing.T) {
+ var (
+ nilIP Addr
+
+ unicast4 = mustIP("192.0.2.1")
+ unicast6 = mustIP("2001:db8::1")
+ unicastZone6 = mustIP("2001:db8::1%eth0")
+ unicast6Unassigned = mustIP("4000::1") // not in 2000::/3.
+
+ multicast4 = mustIP("224.0.0.1")
+ multicast6 = mustIP("ff02::1")
+ multicastZone6 = mustIP("ff02::1%eth0")
+
+ llu4 = mustIP("169.254.0.1")
+ llu6 = mustIP("fe80::1")
+ llu6Last = mustIP("febf:ffff:ffff:ffff:ffff:ffff:ffff:ffff")
+ lluZone6 = mustIP("fe80::1%eth0")
+
+ loopback4 = mustIP("127.0.0.1")
+ loopback6 = mustIP("::1")
+
+ ilm6 = mustIP("ff01::1")
+ ilmZone6 = mustIP("ff01::1%eth0")
+
+ private4a = mustIP("10.0.0.1")
+ private4b = mustIP("172.16.0.1")
+ private4c = mustIP("192.168.1.1")
+ private6 = mustIP("fd00::1")
+
+ unspecified4 = AddrFrom4([4]byte{})
+ unspecified6 = IPv6Unspecified()
+ )
+
+ tests := []struct {
+ name string
+ ip Addr
+ globalUnicast bool
+ interfaceLocalMulticast bool
+ linkLocalMulticast bool
+ linkLocalUnicast bool
+ loopback bool
+ multicast bool
+ private bool
+ unspecified bool
+ }{
+ {
+ name: "nil",
+ ip: nilIP,
+ },
+ {
+ name: "unicast v4Addr",
+ ip: unicast4,
+ globalUnicast: true,
+ },
+ {
+ name: "unicast v6Addr",
+ ip: unicast6,
+ globalUnicast: true,
+ },
+ {
+ name: "unicast v6AddrZone",
+ ip: unicastZone6,
+ globalUnicast: true,
+ },
+ {
+ name: "unicast v6Addr unassigned",
+ ip: unicast6Unassigned,
+ globalUnicast: true,
+ },
+ {
+ name: "multicast v4Addr",
+ ip: multicast4,
+ linkLocalMulticast: true,
+ multicast: true,
+ },
+ {
+ name: "multicast v6Addr",
+ ip: multicast6,
+ linkLocalMulticast: true,
+ multicast: true,
+ },
+ {
+ name: "multicast v6AddrZone",
+ ip: multicastZone6,
+ linkLocalMulticast: true,
+ multicast: true,
+ },
+ {
+ name: "link-local unicast v4Addr",
+ ip: llu4,
+ linkLocalUnicast: true,
+ },
+ {
+ name: "link-local unicast v6Addr",
+ ip: llu6,
+ linkLocalUnicast: true,
+ },
+ {
+ name: "link-local unicast v6Addr upper bound",
+ ip: llu6Last,
+ linkLocalUnicast: true,
+ },
+ {
+ name: "link-local unicast v6AddrZone",
+ ip: lluZone6,
+ linkLocalUnicast: true,
+ },
+ {
+ name: "loopback v4Addr",
+ ip: loopback4,
+ loopback: true,
+ },
+ {
+ name: "loopback v6Addr",
+ ip: loopback6,
+ loopback: true,
+ },
+ {
+ name: "interface-local multicast v6Addr",
+ ip: ilm6,
+ interfaceLocalMulticast: true,
+ multicast: true,
+ },
+ {
+ name: "interface-local multicast v6AddrZone",
+ ip: ilmZone6,
+ interfaceLocalMulticast: true,
+ multicast: true,
+ },
+ {
+ name: "private v4Addr 10/8",
+ ip: private4a,
+ globalUnicast: true,
+ private: true,
+ },
+ {
+ name: "private v4Addr 172.16/12",
+ ip: private4b,
+ globalUnicast: true,
+ private: true,
+ },
+ {
+ name: "private v4Addr 192.168/16",
+ ip: private4c,
+ globalUnicast: true,
+ private: true,
+ },
+ {
+ name: "private v6Addr",
+ ip: private6,
+ globalUnicast: true,
+ private: true,
+ },
+ {
+ name: "unspecified v4Addr",
+ ip: unspecified4,
+ unspecified: true,
+ },
+ {
+ name: "unspecified v6Addr",
+ ip: unspecified6,
+ unspecified: true,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ gu := tt.ip.IsGlobalUnicast()
+ if gu != tt.globalUnicast {
+ t.Errorf("IsGlobalUnicast(%v) = %v; want %v", tt.ip, gu, tt.globalUnicast)
+ }
+
+ ilm := tt.ip.IsInterfaceLocalMulticast()
+ if ilm != tt.interfaceLocalMulticast {
+ t.Errorf("IsInterfaceLocalMulticast(%v) = %v; want %v", tt.ip, ilm, tt.interfaceLocalMulticast)
+ }
+
+ llu := tt.ip.IsLinkLocalUnicast()
+ if llu != tt.linkLocalUnicast {
+ t.Errorf("IsLinkLocalUnicast(%v) = %v; want %v", tt.ip, llu, tt.linkLocalUnicast)
+ }
+
+ llm := tt.ip.IsLinkLocalMulticast()
+ if llm != tt.linkLocalMulticast {
+ t.Errorf("IsLinkLocalMulticast(%v) = %v; want %v", tt.ip, llm, tt.linkLocalMulticast)
+ }
+
+ lo := tt.ip.IsLoopback()
+ if lo != tt.loopback {
+ t.Errorf("IsLoopback(%v) = %v; want %v", tt.ip, lo, tt.loopback)
+ }
+
+ multicast := tt.ip.IsMulticast()
+ if multicast != tt.multicast {
+ t.Errorf("IsMulticast(%v) = %v; want %v", tt.ip, multicast, tt.multicast)
+ }
+
+ private := tt.ip.IsPrivate()
+ if private != tt.private {
+ t.Errorf("IsPrivate(%v) = %v; want %v", tt.ip, private, tt.private)
+ }
+
+ unspecified := tt.ip.IsUnspecified()
+ if unspecified != tt.unspecified {
+ t.Errorf("IsUnspecified(%v) = %v; want %v", tt.ip, unspecified, tt.unspecified)
+ }
+ })
+ }
+}
+
+func TestAddrWellKnown(t *testing.T) {
+ tests := []struct {
+ name string
+ ip Addr
+ std net.IP
+ }{
+ {
+ name: "IPv6 link-local all nodes",
+ ip: IPv6LinkLocalAllNodes(),
+ std: net.IPv6linklocalallnodes,
+ },
+ {
+ name: "IPv6 unspecified",
+ ip: IPv6Unspecified(),
+ std: net.IPv6unspecified,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ want := tt.std.String()
+ got := tt.ip.String()
+
+ if got != want {
+ t.Fatalf("got %s, want %s", got, want)
+ }
+ })
+ }
+}
+
+func TestLessCompare(t *testing.T) {
+ tests := []struct {
+ a, b Addr
+ want bool
+ }{
+ {Addr{}, Addr{}, false},
+ {Addr{}, mustIP("1.2.3.4"), true},
+ {mustIP("1.2.3.4"), Addr{}, false},
+
+ {mustIP("1.2.3.4"), mustIP("0102:0304::0"), true},
+ {mustIP("0102:0304::0"), mustIP("1.2.3.4"), false},
+ {mustIP("1.2.3.4"), mustIP("1.2.3.4"), false},
+
+ {mustIP("::1"), mustIP("::2"), true},
+ {mustIP("::1"), mustIP("::1%foo"), true},
+ {mustIP("::1%foo"), mustIP("::2"), true},
+ {mustIP("::2"), mustIP("::3"), true},
+
+ {mustIP("::"), mustIP("0.0.0.0"), false},
+ {mustIP("0.0.0.0"), mustIP("::"), true},
+
+ {mustIP("::1%a"), mustIP("::1%b"), true},
+ {mustIP("::1%a"), mustIP("::1%a"), false},
+ {mustIP("::1%b"), mustIP("::1%a"), false},
+ }
+ for _, tt := range tests {
+ got := tt.a.Less(tt.b)
+ if got != tt.want {
+ t.Errorf("Less(%q, %q) = %v; want %v", tt.a, tt.b, got, tt.want)
+ }
+ cmp := tt.a.Compare(tt.b)
+ if got && cmp != -1 {
+ t.Errorf("Less(%q, %q) = true, but Compare = %v (not -1)", tt.a, tt.b, cmp)
+ }
+ if cmp < -1 || cmp > 1 {
+ t.Errorf("bogus Compare return value %v", cmp)
+ }
+ if cmp == 0 && tt.a != tt.b {
+ t.Errorf("Compare(%q, %q) = 0; but not equal", tt.a, tt.b)
+ }
+ if cmp == 1 && !tt.b.Less(tt.a) {
+ t.Errorf("Compare(%q, %q) = 1; but b.Less(a) isn't true", tt.a, tt.b)
+ }
+
+ // Also check inverse.
+ if got == tt.want && got {
+ got2 := tt.b.Less(tt.a)
+ if got2 {
+ t.Errorf("Less(%q, %q) was correctly %v, but so was Less(%q, %q)", tt.a, tt.b, got, tt.b, tt.a)
+ }
+ }
+ }
+
+ // And just sort.
+ values := []Addr{
+ mustIP("::1"),
+ mustIP("::2"),
+ Addr{},
+ mustIP("1.2.3.4"),
+ mustIP("8.8.8.8"),
+ mustIP("::1%foo"),
+ }
+ sort.Slice(values, func(i, j int) bool { return values[i].Less(values[j]) })
+ got := fmt.Sprintf("%s", values)
+ want := `[invalid IP 1.2.3.4 8.8.8.8 ::1 ::1%foo ::2]`
+ if got != want {
+ t.Errorf("unexpected sort\n got: %s\nwant: %s\n", got, want)
+ }
+}
+
+func TestIPStringExpanded(t *testing.T) {
+ tests := []struct {
+ ip Addr
+ s string
+ }{
+ {
+ ip: Addr{},
+ s: "invalid IP",
+ },
+ {
+ ip: mustIP("192.0.2.1"),
+ s: "192.0.2.1",
+ },
+ {
+ ip: mustIP("::ffff:192.0.2.1"),
+ s: "0000:0000:0000:0000:0000:ffff:c000:0201",
+ },
+ {
+ ip: mustIP("2001:db8::1"),
+ s: "2001:0db8:0000:0000:0000:0000:0000:0001",
+ },
+ {
+ ip: mustIP("2001:db8::1%eth0"),
+ s: "2001:0db8:0000:0000:0000:0000:0000:0001%eth0",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.ip.String(), func(t *testing.T) {
+ want := tt.s
+ got := tt.ip.StringExpanded()
+
+ if got != want {
+ t.Fatalf("got %s, want %s", got, want)
+ }
+ })
+ }
+}
+
+func TestPrefixMasking(t *testing.T) {
+ type subtest struct {
+ ip Addr
+ bits uint8
+ p Prefix
+ ok bool
+ }
+
+ // makeIPv6 produces a set of IPv6 subtests with an optional zone identifier.
+ makeIPv6 := func(zone string) []subtest {
+ if zone != "" {
+ zone = "%" + zone
+ }
+
+ return []subtest{
+ {
+ ip: mustIP(fmt.Sprintf("2001:db8::1%s", zone)),
+ bits: 255,
+ },
+ {
+ ip: mustIP(fmt.Sprintf("2001:db8::1%s", zone)),
+ bits: 32,
+ p: mustPrefix(fmt.Sprintf("2001:db8::%s/32", zone)),
+ ok: true,
+ },
+ {
+ ip: mustIP(fmt.Sprintf("fe80::dead:beef:dead:beef%s", zone)),
+ bits: 96,
+ p: mustPrefix(fmt.Sprintf("fe80::dead:beef:0:0%s/96", zone)),
+ ok: true,
+ },
+ {
+ ip: mustIP(fmt.Sprintf("aaaa::%s", zone)),
+ bits: 4,
+ p: mustPrefix(fmt.Sprintf("a000::%s/4", zone)),
+ ok: true,
+ },
+ {
+ ip: mustIP(fmt.Sprintf("::%s", zone)),
+ bits: 63,
+ p: mustPrefix(fmt.Sprintf("::%s/63", zone)),
+ ok: true,
+ },
+ }
+ }
+
+ tests := []struct {
+ family string
+ subtests []subtest
+ }{
+ {
+ family: "nil",
+ subtests: []subtest{
+ {
+ bits: 255,
+ ok: true,
+ },
+ {
+ bits: 16,
+ ok: true,
+ },
+ },
+ },
+ {
+ family: "IPv4",
+ subtests: []subtest{
+ {
+ ip: mustIP("192.0.2.0"),
+ bits: 255,
+ },
+ {
+ ip: mustIP("192.0.2.0"),
+ bits: 16,
+ p: mustPrefix("192.0.0.0/16"),
+ ok: true,
+ },
+ {
+ ip: mustIP("255.255.255.255"),
+ bits: 20,
+ p: mustPrefix("255.255.240.0/20"),
+ ok: true,
+ },
+ {
+ // Partially masking one byte that contains both
+ // 1s and 0s on either side of the mask limit.
+ ip: mustIP("100.98.156.66"),
+ bits: 10,
+ p: mustPrefix("100.64.0.0/10"),
+ ok: true,
+ },
+ },
+ },
+ {
+ family: "IPv6",
+ subtests: makeIPv6(""),
+ },
+ {
+ family: "IPv6 zone",
+ subtests: makeIPv6("eth0"),
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.family, func(t *testing.T) {
+ for _, st := range tt.subtests {
+ t.Run(st.p.String(), func(t *testing.T) {
+ // Ensure st.ip is not mutated.
+ orig := st.ip.String()
+
+ p, err := st.ip.Prefix(int(st.bits))
+ if st.ok && err != nil {
+ t.Fatalf("failed to produce prefix: %v", err)
+ }
+ if !st.ok && err == nil {
+ t.Fatal("expected an error, but none occurred")
+ }
+ if err != nil {
+ t.Logf("err: %v", err)
+ return
+ }
+
+ if !reflect.DeepEqual(p, st.p) {
+ t.Errorf("prefix = %q, want %q", p, st.p)
+ }
+
+ if got := st.ip.String(); got != orig {
+ t.Errorf("IP was mutated: %q, want %q", got, orig)
+ }
+ })
+ }
+ })
+ }
+}
+
+func TestPrefixMarshalUnmarshal(t *testing.T) {
+ tests := []string{
+ "",
+ "1.2.3.4/32",
+ "0.0.0.0/0",
+ "::/0",
+ "::1/128",
+ "2001:db8::/32",
+ }
+
+ for _, s := range tests {
+ t.Run(s, func(t *testing.T) {
+ // Ensure that JSON (and by extension, text) marshaling is
+ // sane by entering quoted input.
+ orig := `"` + s + `"`
+
+ var p Prefix
+ if err := json.Unmarshal([]byte(orig), &p); err != nil {
+ t.Fatalf("failed to unmarshal: %v", err)
+ }
+
+ pb, err := json.Marshal(p)
+ if err != nil {
+ t.Fatalf("failed to marshal: %v", err)
+ }
+
+ back := string(pb)
+ if orig != back {
+ t.Errorf("Marshal = %q; want %q", back, orig)
+ }
+ })
+ }
+}
+
+func TestPrefixMarshalUnmarshalZone(t *testing.T) {
+ orig := `"fe80::1cc0:3e8c:119f:c2e1%ens18/128"`
+ unzoned := `"fe80::1cc0:3e8c:119f:c2e1/128"`
+
+ var p Prefix
+ if err := json.Unmarshal([]byte(orig), &p); err != nil {
+ t.Fatalf("failed to unmarshal: %v", err)
+ }
+
+ pb, err := json.Marshal(p)
+ if err != nil {
+ t.Fatalf("failed to marshal: %v", err)
+ }
+
+ back := string(pb)
+ if back != unzoned {
+ t.Errorf("Marshal = %q; want %q", back, unzoned)
+ }
+}
+
+func TestPrefixUnmarshalTextNonZero(t *testing.T) {
+ ip := mustPrefix("fe80::/64")
+ if err := ip.UnmarshalText([]byte("xxx")); err == nil {
+ t.Fatal("unmarshaled into non-empty Prefix")
+ }
+}
+
+func TestIs4AndIs6(t *testing.T) {
+ tests := []struct {
+ ip Addr
+ is4 bool
+ is6 bool
+ }{
+ {Addr{}, false, false},
+ {mustIP("1.2.3.4"), true, false},
+ {mustIP("127.0.0.2"), true, false},
+ {mustIP("::1"), false, true},
+ {mustIP("::ffff:192.0.2.128"), false, true},
+ {mustIP("::fffe:c000:0280"), false, true},
+ {mustIP("::1%eth0"), false, true},
+ }
+ for _, tt := range tests {
+ got4 := tt.ip.Is4()
+ if got4 != tt.is4 {
+ t.Errorf("Is4(%q) = %v; want %v", tt.ip, got4, tt.is4)
+ }
+
+ got6 := tt.ip.Is6()
+ if got6 != tt.is6 {
+ t.Errorf("Is6(%q) = %v; want %v", tt.ip, got6, tt.is6)
+ }
+ }
+}
+
+func TestIs4In6(t *testing.T) {
+ tests := []struct {
+ ip Addr
+ want bool
+ wantUnmap Addr
+ }{
+ {Addr{}, false, Addr{}},
+ {mustIP("::ffff:c000:0280"), true, mustIP("192.0.2.128")},
+ {mustIP("::ffff:192.0.2.128"), true, mustIP("192.0.2.128")},
+ {mustIP("::ffff:192.0.2.128%eth0"), true, mustIP("192.0.2.128")},
+ {mustIP("::fffe:c000:0280"), false, mustIP("::fffe:c000:0280")},
+ {mustIP("::ffff:127.1.2.3"), true, mustIP("127.1.2.3")},
+ {mustIP("::ffff:7f01:0203"), true, mustIP("127.1.2.3")},
+ {mustIP("0:0:0:0:0000:ffff:127.1.2.3"), true, mustIP("127.1.2.3")},
+ {mustIP("0:0:0:0:000000:ffff:127.1.2.3"), true, mustIP("127.1.2.3")},
+ {mustIP("0:0:0:0::ffff:127.1.2.3"), true, mustIP("127.1.2.3")},
+ {mustIP("::1"), false, mustIP("::1")},
+ {mustIP("1.2.3.4"), false, mustIP("1.2.3.4")},
+ }
+ for _, tt := range tests {
+ got := tt.ip.Is4In6()
+ if got != tt.want {
+ t.Errorf("Is4In6(%q) = %v; want %v", tt.ip, got, tt.want)
+ }
+ u := tt.ip.Unmap()
+ if u != tt.wantUnmap {
+ t.Errorf("Unmap(%q) = %v; want %v", tt.ip, u, tt.wantUnmap)
+ }
+ }
+}
+
+func TestPrefixMasked(t *testing.T) {
+ tests := []struct {
+ prefix Prefix
+ masked Prefix
+ }{
+ {
+ prefix: mustPrefix("192.168.0.255/24"),
+ masked: mustPrefix("192.168.0.0/24"),
+ },
+ {
+ prefix: mustPrefix("2100::/3"),
+ masked: mustPrefix("2000::/3"),
+ },
+ {
+ prefix: PrefixFrom(mustIP("2000::"), 129),
+ masked: Prefix{},
+ },
+ {
+ prefix: PrefixFrom(mustIP("1.2.3.4"), 33),
+ masked: Prefix{},
+ },
+ }
+ for _, test := range tests {
+ t.Run(test.prefix.String(), func(t *testing.T) {
+ got := test.prefix.Masked()
+ if got != test.masked {
+ t.Errorf("Masked=%s, want %s", got, test.masked)
+ }
+ })
+ }
+}
+
+func TestPrefix(t *testing.T) {
+ tests := []struct {
+ prefix string
+ ip Addr
+ bits int
+ str string
+ contains []Addr
+ notContains []Addr
+ }{
+ {
+ prefix: "192.168.0.0/24",
+ ip: mustIP("192.168.0.0"),
+ bits: 24,
+ contains: mustIPs("192.168.0.1", "192.168.0.55"),
+ notContains: mustIPs("192.168.1.1", "1.1.1.1"),
+ },
+ {
+ prefix: "192.168.1.1/32",
+ ip: mustIP("192.168.1.1"),
+ bits: 32,
+ contains: mustIPs("192.168.1.1"),
+ notContains: mustIPs("192.168.1.2"),
+ },
+ {
+ prefix: "100.64.0.0/10", // CGNAT range; prefix not multiple of 8
+ ip: mustIP("100.64.0.0"),
+ bits: 10,
+ contains: mustIPs("100.64.0.0", "100.64.0.1", "100.81.251.94", "100.100.100.100", "100.127.255.254", "100.127.255.255"),
+ notContains: mustIPs("100.63.255.255", "100.128.0.0"),
+ },
+ {
+ prefix: "2001:db8::/96",
+ ip: mustIP("2001:db8::"),
+ bits: 96,
+ contains: mustIPs("2001:db8::aaaa:bbbb", "2001:db8::1"),
+ notContains: mustIPs("2001:db8::1:aaaa:bbbb", "2001:db9::"),
+ },
+ {
+ prefix: "0.0.0.0/0",
+ ip: mustIP("0.0.0.0"),
+ bits: 0,
+ contains: mustIPs("192.168.0.1", "1.1.1.1"),
+ notContains: append(mustIPs("2001:db8::1"), Addr{}),
+ },
+ {
+ prefix: "::/0",
+ ip: mustIP("::"),
+ bits: 0,
+ contains: mustIPs("::1", "2001:db8::1"),
+ notContains: mustIPs("192.0.2.1"),
+ },
+ {
+ prefix: "2000::/3",
+ ip: mustIP("2000::"),
+ bits: 3,
+ contains: mustIPs("2001:db8::1"),
+ notContains: mustIPs("fe80::1"),
+ },
+ {
+ prefix: "::%0/00/80",
+ ip: mustIP("::"),
+ bits: 80,
+ str: "::/80",
+ contains: mustIPs("::"),
+ notContains: mustIPs("ff::%0/00", "ff::%1/23", "::%0/00", "::%1/23"),
+ },
+ }
+ for _, test := range tests {
+ t.Run(test.prefix, func(t *testing.T) {
+ prefix, err := ParsePrefix(test.prefix)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if prefix.Addr() != test.ip {
+ t.Errorf("IP=%s, want %s", prefix.Addr(), test.ip)
+ }
+ if prefix.Bits() != test.bits {
+ t.Errorf("bits=%d, want %d", prefix.Bits(), test.bits)
+ }
+ for _, ip := range test.contains {
+ if !prefix.Contains(ip) {
+ t.Errorf("does not contain %s", ip)
+ }
+ }
+ for _, ip := range test.notContains {
+ if prefix.Contains(ip) {
+ t.Errorf("contains %s", ip)
+ }
+ }
+ want := test.str
+ if want == "" {
+ want = test.prefix
+ }
+ if got := prefix.String(); got != want {
+ t.Errorf("prefix.String()=%q, want %q", got, want)
+ }
+
+ TestAppendToMarshal(t, prefix)
+ })
+ }
+}
+
+func TestPrefixFromInvalidBits(t *testing.T) {
+ v4 := MustParseAddr("1.2.3.4")
+ v6 := MustParseAddr("66::66")
+ tests := []struct {
+ ip Addr
+ in, want int
+ }{
+ {v4, 0, 0},
+ {v6, 0, 0},
+ {v4, 1, 1},
+ {v4, 33, -1},
+ {v6, 33, 33},
+ {v6, 127, 127},
+ {v6, 128, 128},
+ {v4, 254, -1},
+ {v4, 255, -1},
+ {v4, -1, -1},
+ {v6, -1, -1},
+ {v4, -5, -1},
+ {v6, -5, -1},
+ }
+ for _, tt := range tests {
+ p := PrefixFrom(tt.ip, tt.in)
+ if got := p.Bits(); got != tt.want {
+ t.Errorf("for (%v, %v), Bits out = %v; want %v", tt.ip, tt.in, got, tt.want)
+ }
+ }
+}
+
+func TestParsePrefixAllocs(t *testing.T) {
+ tests := []struct {
+ ip string
+ slash string
+ }{
+ {"192.168.1.0", "/24"},
+ {"aaaa:bbbb:cccc::", "/24"},
+ }
+ for _, test := range tests {
+ prefix := test.ip + test.slash
+ t.Run(prefix, func(t *testing.T) {
+ ipAllocs := int(testing.AllocsPerRun(5, func() {
+ ParseAddr(test.ip)
+ }))
+ prefixAllocs := int(testing.AllocsPerRun(5, func() {
+ ParsePrefix(prefix)
+ }))
+ if got := prefixAllocs - ipAllocs; got != 0 {
+ t.Errorf("allocs=%d, want 0", got)
+ }
+ })
+ }
+}
+
+func TestParsePrefixError(t *testing.T) {
+ tests := []struct {
+ prefix string
+ errstr string
+ }{
+ {
+ prefix: "192.168.0.0",
+ errstr: "no '/'",
+ },
+ {
+ prefix: "1.257.1.1/24",
+ errstr: "value >255",
+ },
+ {
+ prefix: "1.1.1.0/q",
+ errstr: "bad bits",
+ },
+ {
+ prefix: "1.1.1.0/-1",
+ errstr: "out of range",
+ },
+ {
+ prefix: "1.1.1.0/33",
+ errstr: "out of range",
+ },
+ {
+ prefix: "2001::/129",
+ errstr: "out of range",
+ },
+ }
+ for _, test := range tests {
+ t.Run(test.prefix, func(t *testing.T) {
+ _, err := ParsePrefix(test.prefix)
+ if err == nil {
+ t.Fatal("no error")
+ }
+ if got := err.Error(); !strings.Contains(got, test.errstr) {
+ t.Errorf("error is missing substring %q: %s", test.errstr, got)
+ }
+ })
+ }
+}
+
+func TestPrefixIsSingleIP(t *testing.T) {
+ tests := []struct {
+ ipp Prefix
+ want bool
+ }{
+ {ipp: mustPrefix("127.0.0.1/32"), want: true},
+ {ipp: mustPrefix("127.0.0.1/31"), want: false},
+ {ipp: mustPrefix("127.0.0.1/0"), want: false},
+ {ipp: mustPrefix("::1/128"), want: true},
+ {ipp: mustPrefix("::1/127"), want: false},
+ {ipp: mustPrefix("::1/0"), want: false},
+ {ipp: Prefix{}, want: false},
+ }
+ for _, tt := range tests {
+ got := tt.ipp.IsSingleIP()
+ if got != tt.want {
+ t.Errorf("IsSingleIP(%v) = %v want %v", tt.ipp, got, tt.want)
+ }
+ }
+}
+
+func mustIPs(strs ...string) []Addr {
+ var res []Addr
+ for _, s := range strs {
+ res = append(res, mustIP(s))
+ }
+ return res
+}
+
+func BenchmarkBinaryMarshalRoundTrip(b *testing.B) {
+ b.ReportAllocs()
+ tests := []struct {
+ name string
+ ip string
+ }{
+ {"ipv4", "1.2.3.4"},
+ {"ipv6", "2001:db8::1"},
+ {"ipv6+zone", "2001:db8::1%eth0"},
+ }
+ for _, tc := range tests {
+ b.Run(tc.name, func(b *testing.B) {
+ ip := mustIP(tc.ip)
+ for i := 0; i < b.N; i++ {
+ bt, err := ip.MarshalBinary()
+ if err != nil {
+ b.Fatal(err)
+ }
+ var ip2 Addr
+ if err := ip2.UnmarshalBinary(bt); err != nil {
+ b.Fatal(err)
+ }
+ }
+ })
+ }
+}
+
+func BenchmarkStdIPv4(b *testing.B) {
+ b.ReportAllocs()
+ ips := []net.IP{}
+ for i := 0; i < b.N; i++ {
+ ip := net.IPv4(8, 8, 8, 8)
+ ips = ips[:0]
+ for i := 0; i < 100; i++ {
+ ips = append(ips, ip)
+ }
+ }
+}
+
+func BenchmarkIPv4(b *testing.B) {
+ b.ReportAllocs()
+ ips := []Addr{}
+ for i := 0; i < b.N; i++ {
+ ip := IPv4(8, 8, 8, 8)
+ ips = ips[:0]
+ for i := 0; i < 100; i++ {
+ ips = append(ips, ip)
+ }
+ }
+}
+
+// ip4i was one of the possible representations of IP that came up in
+// discussions, inlining IPv4 addresses, but having an "overflow"
+// interface for IPv6 or IPv6 + zone. This is here for benchmarking.
+type ip4i struct {
+ ip4 [4]byte
+ flags1 byte
+ flags2 byte
+ flags3 byte
+ flags4 byte
+ ipv6 any
+}
+
+func newip4i_v4(a, b, c, d byte) ip4i {
+ return ip4i{ip4: [4]byte{a, b, c, d}}
+}
+
+// BenchmarkIPv4_inline benchmarks the candidate representation, ip4i.
+func BenchmarkIPv4_inline(b *testing.B) {
+ b.ReportAllocs()
+ ips := []ip4i{}
+ for i := 0; i < b.N; i++ {
+ ip := newip4i_v4(8, 8, 8, 8)
+ ips = ips[:0]
+ for i := 0; i < 100; i++ {
+ ips = append(ips, ip)
+ }
+ }
+}
+
+func BenchmarkStdIPv6(b *testing.B) {
+ b.ReportAllocs()
+ ips := []net.IP{}
+ for i := 0; i < b.N; i++ {
+ ip := net.ParseIP("2001:db8::1")
+ ips = ips[:0]
+ for i := 0; i < 100; i++ {
+ ips = append(ips, ip)
+ }
+ }
+}
+
+func BenchmarkIPv6(b *testing.B) {
+ b.ReportAllocs()
+ ips := []Addr{}
+ for i := 0; i < b.N; i++ {
+ ip := mustIP("2001:db8::1")
+ ips = ips[:0]
+ for i := 0; i < 100; i++ {
+ ips = append(ips, ip)
+ }
+ }
+}
+
+func BenchmarkIPv4Contains(b *testing.B) {
+ b.ReportAllocs()
+ prefix := PrefixFrom(IPv4(192, 168, 1, 0), 24)
+ ip := IPv4(192, 168, 1, 1)
+ for i := 0; i < b.N; i++ {
+ prefix.Contains(ip)
+ }
+}
+
+func BenchmarkIPv6Contains(b *testing.B) {
+ b.ReportAllocs()
+ prefix := MustParsePrefix("::1/128")
+ ip := MustParseAddr("::1")
+ for i := 0; i < b.N; i++ {
+ prefix.Contains(ip)
+ }
+}
+
+var parseBenchInputs = []struct {
+ name string
+ ip string
+}{
+ {"v4", "192.168.1.1"},
+ {"v6", "fd7a:115c:a1e0:ab12:4843:cd96:626b:430b"},
+ {"v6_ellipsis", "fd7a:115c::626b:430b"},
+ {"v6_v4", "::ffff:192.168.140.255"},
+ {"v6_zone", "1:2::ffff:192.168.140.255%eth1"},
+}
+
+func BenchmarkParseAddr(b *testing.B) {
+ sinkInternValue = intern.Get("eth1") // Pin to not benchmark the intern package
+ for _, test := range parseBenchInputs {
+ b.Run(test.name, func(b *testing.B) {
+ b.ReportAllocs()
+ for i := 0; i < b.N; i++ {
+ sinkIP, _ = ParseAddr(test.ip)
+ }
+ })
+ }
+}
+
+func BenchmarkStdParseIP(b *testing.B) {
+ for _, test := range parseBenchInputs {
+ b.Run(test.name, func(b *testing.B) {
+ b.ReportAllocs()
+ for i := 0; i < b.N; i++ {
+ sinkStdIP = net.ParseIP(test.ip)
+ }
+ })
+ }
+}
+
+func BenchmarkIPString(b *testing.B) {
+ for _, test := range parseBenchInputs {
+ ip := MustParseAddr(test.ip)
+ b.Run(test.name, func(b *testing.B) {
+ b.ReportAllocs()
+ for i := 0; i < b.N; i++ {
+ sinkString = ip.String()
+ }
+ })
+ }
+}
+
+func BenchmarkIPStringExpanded(b *testing.B) {
+ for _, test := range parseBenchInputs {
+ ip := MustParseAddr(test.ip)
+ b.Run(test.name, func(b *testing.B) {
+ b.ReportAllocs()
+ for i := 0; i < b.N; i++ {
+ sinkString = ip.StringExpanded()
+ }
+ })
+ }
+}
+
+func BenchmarkIPMarshalText(b *testing.B) {
+ b.ReportAllocs()
+ ip := MustParseAddr("66.55.44.33")
+ for i := 0; i < b.N; i++ {
+ sinkBytes, _ = ip.MarshalText()
+ }
+}
+
+func BenchmarkAddrPortString(b *testing.B) {
+ for _, test := range parseBenchInputs {
+ ip := MustParseAddr(test.ip)
+ ipp := AddrPortFrom(ip, 60000)
+ b.Run(test.name, func(b *testing.B) {
+ b.ReportAllocs()
+ for i := 0; i < b.N; i++ {
+ sinkString = ipp.String()
+ }
+ })
+ }
+}
+
+func BenchmarkAddrPortMarshalText(b *testing.B) {
+ for _, test := range parseBenchInputs {
+ ip := MustParseAddr(test.ip)
+ ipp := AddrPortFrom(ip, 60000)
+ b.Run(test.name, func(b *testing.B) {
+ b.ReportAllocs()
+ for i := 0; i < b.N; i++ {
+ sinkBytes, _ = ipp.MarshalText()
+ }
+ })
+ }
+}
+
+func BenchmarkPrefixMasking(b *testing.B) {
+ tests := []struct {
+ name string
+ ip Addr
+ bits int
+ }{
+ {
+ name: "IPv4 /32",
+ ip: IPv4(192, 0, 2, 0),
+ bits: 32,
+ },
+ {
+ name: "IPv4 /17",
+ ip: IPv4(192, 0, 2, 0),
+ bits: 17,
+ },
+ {
+ name: "IPv4 /0",
+ ip: IPv4(192, 0, 2, 0),
+ bits: 0,
+ },
+ {
+ name: "IPv6 /128",
+ ip: mustIP("2001:db8::1"),
+ bits: 128,
+ },
+ {
+ name: "IPv6 /65",
+ ip: mustIP("2001:db8::1"),
+ bits: 65,
+ },
+ {
+ name: "IPv6 /0",
+ ip: mustIP("2001:db8::1"),
+ bits: 0,
+ },
+ {
+ name: "IPv6 zone /128",
+ ip: mustIP("2001:db8::1%eth0"),
+ bits: 128,
+ },
+ {
+ name: "IPv6 zone /65",
+ ip: mustIP("2001:db8::1%eth0"),
+ bits: 65,
+ },
+ {
+ name: "IPv6 zone /0",
+ ip: mustIP("2001:db8::1%eth0"),
+ bits: 0,
+ },
+ }
+
+ for _, tt := range tests {
+ b.Run(tt.name, func(b *testing.B) {
+ b.ReportAllocs()
+
+ for i := 0; i < b.N; i++ {
+ sinkPrefix, _ = tt.ip.Prefix(tt.bits)
+ }
+ })
+ }
+}
+
+func BenchmarkPrefixMarshalText(b *testing.B) {
+ b.ReportAllocs()
+ ipp := MustParsePrefix("66.55.44.33/22")
+ for i := 0; i < b.N; i++ {
+ sinkBytes, _ = ipp.MarshalText()
+ }
+}
+
+func BenchmarkParseAddrPort(b *testing.B) {
+ for _, test := range parseBenchInputs {
+ var ipp string
+ if strings.HasPrefix(test.name, "v6") {
+ ipp = fmt.Sprintf("[%s]:1234", test.ip)
+ } else {
+ ipp = fmt.Sprintf("%s:1234", test.ip)
+ }
+ b.Run(test.name, func(b *testing.B) {
+ b.ReportAllocs()
+
+ for i := 0; i < b.N; i++ {
+ sinkAddrPort, _ = ParseAddrPort(ipp)
+ }
+ })
+ }
+}
+
+func TestAs4(t *testing.T) {
+ tests := []struct {
+ ip Addr
+ want [4]byte
+ wantPanic bool
+ }{
+ {
+ ip: mustIP("1.2.3.4"),
+ want: [4]byte{1, 2, 3, 4},
+ },
+ {
+ ip: AddrFrom16(mustIP("1.2.3.4").As16()), // IPv4-in-IPv6
+ want: [4]byte{1, 2, 3, 4},
+ },
+ {
+ ip: mustIP("0.0.0.0"),
+ want: [4]byte{0, 0, 0, 0},
+ },
+ {
+ ip: Addr{},
+ wantPanic: true,
+ },
+ {
+ ip: mustIP("::1"),
+ wantPanic: true,
+ },
+ }
+ as4 := func(ip Addr) (v [4]byte, gotPanic bool) {
+ defer func() {
+ if recover() != nil {
+ gotPanic = true
+ return
+ }
+ }()
+ v = ip.As4()
+ return
+ }
+ for i, tt := range tests {
+ got, gotPanic := as4(tt.ip)
+ if gotPanic != tt.wantPanic {
+ t.Errorf("%d. panic on %v = %v; want %v", i, tt.ip, gotPanic, tt.wantPanic)
+ continue
+ }
+ if got != tt.want {
+ t.Errorf("%d. %v = %v; want %v", i, tt.ip, got, tt.want)
+ }
+ }
+}
+
+func TestPrefixOverlaps(t *testing.T) {
+ pfx := mustPrefix
+ tests := []struct {
+ a, b Prefix
+ want bool
+ }{
+ {Prefix{}, pfx("1.2.0.0/16"), false}, // first zero
+ {pfx("1.2.0.0/16"), Prefix{}, false}, // second zero
+ {pfx("::0/3"), pfx("0.0.0.0/3"), false}, // different families
+
+ {pfx("1.2.0.0/16"), pfx("1.2.0.0/16"), true}, // equal
+
+ {pfx("1.2.0.0/16"), pfx("1.2.3.0/24"), true},
+ {pfx("1.2.3.0/24"), pfx("1.2.0.0/16"), true},
+
+ {pfx("1.2.0.0/16"), pfx("1.2.3.0/32"), true},
+ {pfx("1.2.3.0/32"), pfx("1.2.0.0/16"), true},
+
+ // Match /0 either order
+ {pfx("1.2.3.0/32"), pfx("0.0.0.0/0"), true},
+ {pfx("0.0.0.0/0"), pfx("1.2.3.0/32"), true},
+
+ {pfx("1.2.3.0/32"), pfx("5.5.5.5/0"), true}, // normalization not required; /0 means true
+
+ // IPv6 overlapping
+ {pfx("5::1/128"), pfx("5::0/8"), true},
+ {pfx("5::0/8"), pfx("5::1/128"), true},
+
+ // IPv6 not overlapping
+ {pfx("1::1/128"), pfx("2::2/128"), false},
+ {pfx("0100::0/8"), pfx("::1/128"), false},
+
+ // v6-mapped v4 should not overlap with IPv4.
+ {PrefixFrom(AddrFrom16(mustIP("1.2.0.0").As16()), 16), pfx("1.2.3.0/24"), false},
+
+ // Invalid prefixes
+ {PrefixFrom(mustIP("1.2.3.4"), 33), pfx("1.2.3.0/24"), false},
+ {PrefixFrom(mustIP("2000::"), 129), pfx("2000::/64"), false},
+ }
+ for i, tt := range tests {
+ if got := tt.a.Overlaps(tt.b); got != tt.want {
+ t.Errorf("%d. (%v).Overlaps(%v) = %v; want %v", i, tt.a, tt.b, got, tt.want)
+ }
+ // Overlaps is commutative
+ if got := tt.b.Overlaps(tt.a); got != tt.want {
+ t.Errorf("%d. (%v).Overlaps(%v) = %v; want %v", i, tt.b, tt.a, got, tt.want)
+ }
+ }
+}
+
+// Sink variables are here to force the compiler to not elide
+// seemingly useless work in benchmarks and allocation tests. If you
+// were to just `_ = foo()` within a test function, the compiler could
+// correctly deduce that foo() does nothing and doesn't need to be
+// called. By writing results to a global variable, we hide that fact
+// from the compiler and force it to keep the code under test.
+var (
+ sinkIP Addr
+ sinkStdIP net.IP
+ sinkAddrPort AddrPort
+ sinkPrefix Prefix
+ sinkPrefixSlice []Prefix
+ sinkInternValue *intern.Value
+ sinkIP16 [16]byte
+ sinkIP4 [4]byte
+ sinkBool bool
+ sinkString string
+ sinkBytes []byte
+ sinkUDPAddr = &net.UDPAddr{IP: make(net.IP, 0, 16)}
+)
+
+func TestNoAllocs(t *testing.T) {
+ // Wrappers that panic on error, to prove that our alloc-free
+ // methods are returning successfully.
+ panicIP := func(ip Addr, err error) Addr {
+ if err != nil {
+ panic(err)
+ }
+ return ip
+ }
+ panicPfx := func(pfx Prefix, err error) Prefix {
+ if err != nil {
+ panic(err)
+ }
+ return pfx
+ }
+ panicIPP := func(ipp AddrPort, err error) AddrPort {
+ if err != nil {
+ panic(err)
+ }
+ return ipp
+ }
+ test := func(name string, f func()) {
+ t.Run(name, func(t *testing.T) {
+ n := testing.AllocsPerRun(1000, f)
+ if n != 0 {
+ t.Fatalf("allocs = %d; want 0", int(n))
+ }
+ })
+ }
+
+ // IP constructors
+ test("IPv4", func() { sinkIP = IPv4(1, 2, 3, 4) })
+ test("AddrFrom4", func() { sinkIP = AddrFrom4([4]byte{1, 2, 3, 4}) })
+ test("AddrFrom16", func() { sinkIP = AddrFrom16([16]byte{}) })
+ test("ParseAddr/4", func() { sinkIP = panicIP(ParseAddr("1.2.3.4")) })
+ test("ParseAddr/6", func() { sinkIP = panicIP(ParseAddr("::1")) })
+ test("MustParseAddr", func() { sinkIP = MustParseAddr("1.2.3.4") })
+ test("IPv6LinkLocalAllNodes", func() { sinkIP = IPv6LinkLocalAllNodes() })
+ test("IPv6Unspecified", func() { sinkIP = IPv6Unspecified() })
+
+ // IP methods
+ test("IP.IsZero", func() { sinkBool = MustParseAddr("1.2.3.4").IsZero() })
+ test("IP.BitLen", func() { sinkBool = MustParseAddr("1.2.3.4").BitLen() == 8 })
+ test("IP.Zone/4", func() { sinkBool = MustParseAddr("1.2.3.4").Zone() == "" })
+ test("IP.Zone/6", func() { sinkBool = MustParseAddr("fe80::1").Zone() == "" })
+ test("IP.Zone/6zone", func() { sinkBool = MustParseAddr("fe80::1%zone").Zone() == "" })
+ test("IP.Compare", func() {
+ a := MustParseAddr("1.2.3.4")
+ b := MustParseAddr("2.3.4.5")
+ sinkBool = a.Compare(b) == 0
+ })
+ test("IP.Less", func() {
+ a := MustParseAddr("1.2.3.4")
+ b := MustParseAddr("2.3.4.5")
+ sinkBool = a.Less(b)
+ })
+ test("IP.Is4", func() { sinkBool = MustParseAddr("1.2.3.4").Is4() })
+ test("IP.Is6", func() { sinkBool = MustParseAddr("fe80::1").Is6() })
+ test("IP.Is4In6", func() { sinkBool = MustParseAddr("fe80::1").Is4In6() })
+ test("IP.Unmap", func() { sinkIP = MustParseAddr("ffff::2.3.4.5").Unmap() })
+ test("IP.WithZone", func() { sinkIP = MustParseAddr("fe80::1").WithZone("") })
+ test("IP.IsGlobalUnicast", func() { sinkBool = MustParseAddr("2001:db8::1").IsGlobalUnicast() })
+ test("IP.IsInterfaceLocalMulticast", func() { sinkBool = MustParseAddr("fe80::1").IsInterfaceLocalMulticast() })
+ test("IP.IsLinkLocalMulticast", func() { sinkBool = MustParseAddr("fe80::1").IsLinkLocalMulticast() })
+ test("IP.IsLinkLocalUnicast", func() { sinkBool = MustParseAddr("fe80::1").IsLinkLocalUnicast() })
+ test("IP.IsLoopback", func() { sinkBool = MustParseAddr("fe80::1").IsLoopback() })
+ test("IP.IsMulticast", func() { sinkBool = MustParseAddr("fe80::1").IsMulticast() })
+ test("IP.IsPrivate", func() { sinkBool = MustParseAddr("fd00::1").IsPrivate() })
+ test("IP.IsUnspecified", func() { sinkBool = IPv6Unspecified().IsUnspecified() })
+ test("IP.Prefix/4", func() { sinkPrefix = panicPfx(MustParseAddr("1.2.3.4").Prefix(20)) })
+ test("IP.Prefix/6", func() { sinkPrefix = panicPfx(MustParseAddr("fe80::1").Prefix(64)) })
+ test("IP.As16", func() { sinkIP16 = MustParseAddr("1.2.3.4").As16() })
+ test("IP.As4", func() { sinkIP4 = MustParseAddr("1.2.3.4").As4() })
+ test("IP.Next", func() { sinkIP = MustParseAddr("1.2.3.4").Next() })
+ test("IP.Prev", func() { sinkIP = MustParseAddr("1.2.3.4").Prev() })
+
+ // AddrPort constructors
+ test("AddrPortFrom", func() { sinkAddrPort = AddrPortFrom(IPv4(1, 2, 3, 4), 22) })
+ test("ParseAddrPort", func() { sinkAddrPort = panicIPP(ParseAddrPort("[::1]:1234")) })
+ test("MustParseAddrPort", func() { sinkAddrPort = MustParseAddrPort("[::1]:1234") })
+
+ // Prefix constructors
+ test("PrefixFrom", func() { sinkPrefix = PrefixFrom(IPv4(1, 2, 3, 4), 32) })
+ test("ParsePrefix/4", func() { sinkPrefix = panicPfx(ParsePrefix("1.2.3.4/20")) })
+ test("ParsePrefix/6", func() { sinkPrefix = panicPfx(ParsePrefix("fe80::1/64")) })
+ test("MustParsePrefix", func() { sinkPrefix = MustParsePrefix("1.2.3.4/20") })
+
+ // Prefix methods
+ test("Prefix.Contains", func() { sinkBool = MustParsePrefix("1.2.3.0/24").Contains(MustParseAddr("1.2.3.4")) })
+ test("Prefix.Overlaps", func() {
+ a, b := MustParsePrefix("1.2.3.0/24"), MustParsePrefix("1.2.0.0/16")
+ sinkBool = a.Overlaps(b)
+ })
+ test("Prefix.IsZero", func() { sinkBool = MustParsePrefix("1.2.0.0/16").IsZero() })
+ test("Prefix.IsSingleIP", func() { sinkBool = MustParsePrefix("1.2.3.4/32").IsSingleIP() })
+ test("IPPRefix.Masked", func() { sinkPrefix = MustParsePrefix("1.2.3.4/16").Masked() })
+}
+
+func TestPrefixString(t *testing.T) {
+ tests := []struct {
+ ipp Prefix
+ want string
+ }{
+ {Prefix{}, "invalid Prefix"},
+ {PrefixFrom(Addr{}, 8), "invalid Prefix"},
+ {PrefixFrom(MustParseAddr("1.2.3.4"), 88), "invalid Prefix"},
+ }
+
+ for _, tt := range tests {
+ if got := tt.ipp.String(); got != tt.want {
+ t.Errorf("(%#v).String() = %q want %q", tt.ipp, got, tt.want)
+ }
+ }
+}
+
+func TestInvalidAddrPortString(t *testing.T) {
+ tests := []struct {
+ ipp AddrPort
+ want string
+ }{
+ {AddrPort{}, "invalid AddrPort"},
+ {AddrPortFrom(Addr{}, 80), "invalid AddrPort"},
+ }
+
+ for _, tt := range tests {
+ if got := tt.ipp.String(); got != tt.want {
+ t.Errorf("(%#v).String() = %q want %q", tt.ipp, got, tt.want)
+ }
+ }
+}
+
+func TestAsSlice(t *testing.T) {
+ tests := []struct {
+ in Addr
+ want []byte
+ }{
+ {in: Addr{}, want: nil},
+ {in: mustIP("1.2.3.4"), want: []byte{1, 2, 3, 4}},
+ {in: mustIP("ffff::1"), want: []byte{0xff, 0xff, 15: 1}},
+ }
+
+ for _, test := range tests {
+ got := test.in.AsSlice()
+ if !bytes.Equal(got, test.want) {
+ t.Errorf("%v.AsSlice() = %v want %v", test.in, got, test.want)
+ }
+ }
+}
+
+var sink16 [16]byte
+
+func BenchmarkAs16(b *testing.B) {
+ addr := MustParseAddr("1::10")
+ for i := 0; i < b.N; i++ {
+ sink16 = addr.As16()
+ }
+}
diff --git a/libgo/go/net/netip/slow_test.go b/libgo/go/net/netip/slow_test.go
new file mode 100644
index 0000000..5b46a39
--- /dev/null
+++ b/libgo/go/net/netip/slow_test.go
@@ -0,0 +1,190 @@
+// Copyright 2020 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package netip_test
+
+import (
+ "fmt"
+ . "net/netip"
+ "strconv"
+ "strings"
+)
+
+// zeros is a slice of eight stringified zeros. It's used in
+// parseIPSlow to construct slices of specific amounts of zero fields,
+// from 1 to 8.
+var zeros = []string{"0", "0", "0", "0", "0", "0", "0", "0"}
+
+// parseIPSlow is like ParseIP, but aims for readability above
+// speed. It's the reference implementation for correctness checking
+// and against which we measure optimized parsers.
+//
+// parseIPSlow understands the following forms of IP addresses:
+// - Regular IPv4: 1.2.3.4
+// - IPv4 with many leading zeros: 0000001.0000002.0000003.0000004
+// - Regular IPv6: 1111:2222:3333:4444:5555:6666:7777:8888
+// - IPv6 with many leading zeros: 00000001:0000002:0000003:0000004:0000005:0000006:0000007:0000008
+// - IPv6 with zero blocks elided: 1111:2222::7777:8888
+// - IPv6 with trailing 32 bits expressed as IPv4: 1111:2222:3333:4444:5555:6666:77.77.88.88
+//
+// It does not process the following IP address forms, which have been
+// varyingly accepted by some programs due to an under-specification
+// of the shapes of IPv4 addresses:
+//
+// - IPv4 as a single 32-bit uint: 4660 (same as "1.2.3.4")
+// - IPv4 with octal numbers: 0300.0250.0.01 (same as "192.168.0.1")
+// - IPv4 with hex numbers: 0xc0.0xa8.0x0.0x1 (same as "192.168.0.1")
+// - IPv4 in "class-B style": 1.2.52 (same as "1.2.3.4")
+// - IPv4 in "class-A style": 1.564 (same as "1.2.3.4")
+func parseIPSlow(s string) (Addr, error) {
+ // Identify and strip out the zone, if any. There should be 0 or 1
+ // '%' in the string.
+ var zone string
+ fs := strings.Split(s, "%")
+ switch len(fs) {
+ case 1:
+ // No zone, that's fine.
+ case 2:
+ s, zone = fs[0], fs[1]
+ if zone == "" {
+ return Addr{}, fmt.Errorf("netaddr.ParseIP(%q): no zone after zone specifier", s)
+ }
+ default:
+ return Addr{}, fmt.Errorf("netaddr.ParseIP(%q): too many zone specifiers", s) // TODO: less specific?
+ }
+
+ // IPv4 by itself is easy to do in a helper.
+ if strings.Count(s, ":") == 0 {
+ if zone != "" {
+ return Addr{}, fmt.Errorf("netaddr.ParseIP(%q): IPv4 addresses cannot have a zone", s)
+ }
+ return parseIPv4Slow(s)
+ }
+
+ normal, err := normalizeIPv6Slow(s)
+ if err != nil {
+ return Addr{}, err
+ }
+
+ // At this point, we've normalized the address back into 8 hex
+ // fields of 16 bits each. Parse that.
+ fs = strings.Split(normal, ":")
+ if len(fs) != 8 {
+ return Addr{}, fmt.Errorf("netaddr.ParseIP(%q): wrong size address", s)
+ }
+ var ret [16]byte
+ for i, f := range fs {
+ a, b, err := parseWord(f)
+ if err != nil {
+ return Addr{}, err
+ }
+ ret[i*2] = a
+ ret[i*2+1] = b
+ }
+
+ return AddrFrom16(ret).WithZone(zone), nil
+}
+
+// normalizeIPv6Slow expands s, which is assumed to be an IPv6
+// address, to its canonical text form.
+//
+// The canonical form of an IPv6 address is 8 colon-separated fields,
+// where each field should be a hex value from 0 to ffff. This
+// function does not verify the contents of each field.
+//
+// This function performs two transformations:
+// - The last 32 bits of an IPv6 address may be represented in
+// IPv4-style dotted quad form, as in 1:2:3:4:5:6:7.8.9.10. That
+// address is transformed to its hex equivalent,
+// e.g. 1:2:3:4:5:6:708:90a.
+// - An address may contain one "::", which expands into as many
+// 16-bit blocks of zeros as needed to make the address its correct
+// full size. For example, fe80::1:2 expands to fe80:0:0:0:0:0:1:2.
+//
+// Both short forms may be present in a single address,
+// e.g. fe80::1.2.3.4.
+func normalizeIPv6Slow(orig string) (string, error) {
+ s := orig
+
+ // Find and convert an IPv4 address in the final field, if any.
+ i := strings.LastIndex(s, ":")
+ if i == -1 {
+ return "", fmt.Errorf("netaddr.ParseIP(%q): invalid IP address", orig)
+ }
+ if strings.Contains(s[i+1:], ".") {
+ ip, err := parseIPv4Slow(s[i+1:])
+ if err != nil {
+ return "", err
+ }
+ a4 := ip.As4()
+ s = fmt.Sprintf("%s:%02x%02x:%02x%02x", s[:i], a4[0], a4[1], a4[2], a4[3])
+ }
+
+ // Find and expand a ::, if any.
+ fs := strings.Split(s, "::")
+ switch len(fs) {
+ case 1:
+ // No ::, nothing to do.
+ case 2:
+ lhs, rhs := fs[0], fs[1]
+ // Found a ::, figure out how many zero blocks need to be
+ // inserted.
+ nblocks := strings.Count(lhs, ":") + strings.Count(rhs, ":")
+ if lhs != "" {
+ nblocks++
+ }
+ if rhs != "" {
+ nblocks++
+ }
+ if nblocks > 7 {
+ return "", fmt.Errorf("netaddr.ParseIP(%q): address too long", orig)
+ }
+ fs = nil
+ // Either side of the :: can be empty. We don't want empty
+ // fields to feature in the final normalized address.
+ if lhs != "" {
+ fs = append(fs, lhs)
+ }
+ fs = append(fs, zeros[:8-nblocks]...)
+ if rhs != "" {
+ fs = append(fs, rhs)
+ }
+ s = strings.Join(fs, ":")
+ default:
+ // Too many ::
+ return "", fmt.Errorf("netaddr.ParseIP(%q): invalid IP address", orig)
+ }
+
+ return s, nil
+}
+
+// parseIPv4Slow parses and returns an IPv4 address in dotted quad
+// form, e.g. "192.168.0.1". It is slow but easy to read, and the
+// reference implementation against which we compare faster
+// implementations for correctness.
+func parseIPv4Slow(s string) (Addr, error) {
+ fs := strings.Split(s, ".")
+ if len(fs) != 4 {
+ return Addr{}, fmt.Errorf("netaddr.ParseIP(%q): invalid IP address", s)
+ }
+ var ret [4]byte
+ for i := range ret {
+ val, err := strconv.ParseUint(fs[i], 10, 8)
+ if err != nil {
+ return Addr{}, err
+ }
+ ret[i] = uint8(val)
+ }
+ return AddrFrom4([4]byte{ret[0], ret[1], ret[2], ret[3]}), nil
+}
+
+// parseWord converts a 16-bit hex string into its corresponding
+// two-byte value.
+func parseWord(s string) (byte, byte, error) {
+ ret, err := strconv.ParseUint(s, 16, 16)
+ if err != nil {
+ return 0, 0, err
+ }
+ return uint8(ret >> 8), uint8(ret), nil
+}
diff --git a/libgo/go/net/netip/uint128.go b/libgo/go/net/netip/uint128.go
new file mode 100644
index 0000000..738939d
--- /dev/null
+++ b/libgo/go/net/netip/uint128.go
@@ -0,0 +1,92 @@
+// Copyright 2020 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package netip
+
+import "math/bits"
+
+// uint128 represents a uint128 using two uint64s.
+//
+// When the methods below mention a bit number, bit 0 is the most
+// significant bit (in hi) and bit 127 is the lowest (lo&1).
+type uint128 struct {
+ hi uint64
+ lo uint64
+}
+
+// mask6 returns a uint128 bitmask with the topmost n bits of a
+// 128-bit number.
+func mask6(n int) uint128 {
+ return uint128{^(^uint64(0) >> n), ^uint64(0) << (128 - n)}
+}
+
+// isZero reports whether u == 0.
+//
+// It's faster than u == (uint128{}) because the compiler (as of Go
+// 1.15/1.16b1) doesn't do this trick and instead inserts a branch in
+// its eq alg's generated code.
+func (u uint128) isZero() bool { return u.hi|u.lo == 0 }
+
+// and returns the bitwise AND of u and m (u&m).
+func (u uint128) and(m uint128) uint128 {
+ return uint128{u.hi & m.hi, u.lo & m.lo}
+}
+
+// xor returns the bitwise XOR of u and m (u^m).
+func (u uint128) xor(m uint128) uint128 {
+ return uint128{u.hi ^ m.hi, u.lo ^ m.lo}
+}
+
+// or returns the bitwise OR of u and m (u|m).
+func (u uint128) or(m uint128) uint128 {
+ return uint128{u.hi | m.hi, u.lo | m.lo}
+}
+
+// not returns the bitwise NOT of u.
+func (u uint128) not() uint128 {
+ return uint128{^u.hi, ^u.lo}
+}
+
+// subOne returns u - 1.
+func (u uint128) subOne() uint128 {
+ lo, borrow := bits.Sub64(u.lo, 1, 0)
+ return uint128{u.hi - borrow, lo}
+}
+
+// addOne returns u + 1.
+func (u uint128) addOne() uint128 {
+ lo, carry := bits.Add64(u.lo, 1, 0)
+ return uint128{u.hi + carry, lo}
+}
+
+func u64CommonPrefixLen(a, b uint64) uint8 {
+ return uint8(bits.LeadingZeros64(a ^ b))
+}
+
+func (u uint128) commonPrefixLen(v uint128) (n uint8) {
+ if n = u64CommonPrefixLen(u.hi, v.hi); n == 64 {
+ n += u64CommonPrefixLen(u.lo, v.lo)
+ }
+ return
+}
+
+// halves returns the two uint64 halves of the uint128.
+//
+// Logically, think of it as returning two uint64s.
+// It only returns pointers for inlining reasons on 32-bit platforms.
+func (u *uint128) halves() [2]*uint64 {
+ return [2]*uint64{&u.hi, &u.lo}
+}
+
+// bitsSetFrom returns a copy of u with the given bit
+// and all subsequent ones set.
+func (u uint128) bitsSetFrom(bit uint8) uint128 {
+ return u.or(mask6(int(bit)).not())
+}
+
+// bitsClearedFrom returns a copy of u with the given bit
+// and all subsequent ones cleared.
+func (u uint128) bitsClearedFrom(bit uint8) uint128 {
+ return u.and(mask6(int(bit)))
+}
diff --git a/libgo/go/net/netip/uint128_test.go b/libgo/go/net/netip/uint128_test.go
new file mode 100644
index 0000000..dd1ae0e
--- /dev/null
+++ b/libgo/go/net/netip/uint128_test.go
@@ -0,0 +1,89 @@
+// Copyright 2020 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package netip
+
+import (
+ "testing"
+)
+
+func TestUint128AddSub(t *testing.T) {
+ const add1 = 1
+ const sub1 = -1
+ tests := []struct {
+ in uint128
+ op int // +1 or -1 to add vs subtract
+ want uint128
+ }{
+ {uint128{0, 0}, add1, uint128{0, 1}},
+ {uint128{0, 1}, add1, uint128{0, 2}},
+ {uint128{1, 0}, add1, uint128{1, 1}},
+ {uint128{0, ^uint64(0)}, add1, uint128{1, 0}},
+ {uint128{^uint64(0), ^uint64(0)}, add1, uint128{0, 0}},
+
+ {uint128{0, 0}, sub1, uint128{^uint64(0), ^uint64(0)}},
+ {uint128{0, 1}, sub1, uint128{0, 0}},
+ {uint128{0, 2}, sub1, uint128{0, 1}},
+ {uint128{1, 0}, sub1, uint128{0, ^uint64(0)}},
+ {uint128{1, 1}, sub1, uint128{1, 0}},
+ }
+ for _, tt := range tests {
+ var got uint128
+ switch tt.op {
+ case add1:
+ got = tt.in.addOne()
+ case sub1:
+ got = tt.in.subOne()
+ default:
+ panic("bogus op")
+ }
+ if got != tt.want {
+ t.Errorf("%v add %d = %v; want %v", tt.in, tt.op, got, tt.want)
+ }
+ }
+}
+
+func TestBitsSetFrom(t *testing.T) {
+ tests := []struct {
+ bit uint8
+ want uint128
+ }{
+ {0, uint128{^uint64(0), ^uint64(0)}},
+ {1, uint128{^uint64(0) >> 1, ^uint64(0)}},
+ {63, uint128{1, ^uint64(0)}},
+ {64, uint128{0, ^uint64(0)}},
+ {65, uint128{0, ^uint64(0) >> 1}},
+ {127, uint128{0, 1}},
+ {128, uint128{0, 0}},
+ }
+ for _, tt := range tests {
+ var zero uint128
+ got := zero.bitsSetFrom(tt.bit)
+ if got != tt.want {
+ t.Errorf("0.bitsSetFrom(%d) = %064b want %064b", tt.bit, got, tt.want)
+ }
+ }
+}
+
+func TestBitsClearedFrom(t *testing.T) {
+ tests := []struct {
+ bit uint8
+ want uint128
+ }{
+ {0, uint128{0, 0}},
+ {1, uint128{1 << 63, 0}},
+ {63, uint128{^uint64(0) &^ 1, 0}},
+ {64, uint128{^uint64(0), 0}},
+ {65, uint128{^uint64(0), 1 << 63}},
+ {127, uint128{^uint64(0), ^uint64(0) &^ 1}},
+ {128, uint128{^uint64(0), ^uint64(0)}},
+ }
+ for _, tt := range tests {
+ ones := uint128{^uint64(0), ^uint64(0)}
+ got := ones.bitsClearedFrom(tt.bit)
+ if got != tt.want {
+ t.Errorf("ones.bitsClearedFrom(%d) = %064b want %064b", tt.bit, got, tt.want)
+ }
+ }
+}
diff --git a/libgo/go/net/nss.go b/libgo/go/net/nss.go
index c12ee75..3e5274d 100644
--- a/libgo/go/net/nss.go
+++ b/libgo/go/net/nss.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build aix || darwin || dragonfly || freebsd || hurd || linux || netbsd || openbsd || solaris
-// +build aix darwin dragonfly freebsd hurd linux netbsd openbsd solaris
package net
diff --git a/libgo/go/net/nss_test.go b/libgo/go/net/nss_test.go
index 948b8d3..b9a23ab 100644
--- a/libgo/go/net/nss_test.go
+++ b/libgo/go/net/nss_test.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build darwin || dragonfly || freebsd || hurd || linux || netbsd || openbsd || solaris
-// +build darwin dragonfly freebsd hurd linux netbsd openbsd solaris
package net
diff --git a/libgo/go/net/packetconn_test.go b/libgo/go/net/packetconn_test.go
index aeb9845..fa160df 100644
--- a/libgo/go/net/packetconn_test.go
+++ b/libgo/go/net/packetconn_test.go
@@ -6,14 +6,12 @@
// tag.
//go:build !js
-// +build !js
package net
import (
"os"
"testing"
- "time"
)
// The full stack test cases for IPConn have been moved to the
@@ -29,16 +27,16 @@ func packetConnTestData(t *testing.T, network string) ([]byte, func()) {
return []byte("PACKETCONN TEST"), nil
}
-var packetConnTests = []struct {
- net string
- addr1 string
- addr2 string
-}{
- {"udp", "127.0.0.1:0", "127.0.0.1:0"},
- {"unixgram", testUnixAddr(), testUnixAddr()},
-}
-
func TestPacketConn(t *testing.T) {
+ var packetConnTests = []struct {
+ net string
+ addr1 string
+ addr2 string
+ }{
+ {"udp", "127.0.0.1:0", "127.0.0.1:0"},
+ {"unixgram", testUnixAddr(t), testUnixAddr(t)},
+ }
+
closer := func(c PacketConn, net, addr1, addr2 string) {
c.Close()
switch net {
@@ -61,9 +59,6 @@ func TestPacketConn(t *testing.T) {
}
defer closer(c1, tt.net, tt.addr1, tt.addr2)
c1.LocalAddr()
- c1.SetDeadline(time.Now().Add(500 * time.Millisecond))
- c1.SetReadDeadline(time.Now().Add(500 * time.Millisecond))
- c1.SetWriteDeadline(time.Now().Add(500 * time.Millisecond))
c2, err := ListenPacket(tt.net, tt.addr2)
if err != nil {
@@ -71,9 +66,6 @@ func TestPacketConn(t *testing.T) {
}
defer closer(c2, tt.net, tt.addr1, tt.addr2)
c2.LocalAddr()
- c2.SetDeadline(time.Now().Add(500 * time.Millisecond))
- c2.SetReadDeadline(time.Now().Add(500 * time.Millisecond))
- c2.SetWriteDeadline(time.Now().Add(500 * time.Millisecond))
rb2 := make([]byte, 128)
if _, err := c1.WriteTo(wb, c2.LocalAddr()); err != nil {
@@ -93,6 +85,15 @@ func TestPacketConn(t *testing.T) {
}
func TestConnAndPacketConn(t *testing.T) {
+ var packetConnTests = []struct {
+ net string
+ addr1 string
+ addr2 string
+ }{
+ {"udp", "127.0.0.1:0", "127.0.0.1:0"},
+ {"unixgram", testUnixAddr(t), testUnixAddr(t)},
+ }
+
closer := func(c PacketConn, net, addr1, addr2 string) {
c.Close()
switch net {
@@ -116,9 +117,6 @@ func TestConnAndPacketConn(t *testing.T) {
}
defer closer(c1, tt.net, tt.addr1, tt.addr2)
c1.LocalAddr()
- c1.SetDeadline(time.Now().Add(500 * time.Millisecond))
- c1.SetReadDeadline(time.Now().Add(500 * time.Millisecond))
- c1.SetWriteDeadline(time.Now().Add(500 * time.Millisecond))
c2, err := Dial(tt.net, c1.LocalAddr().String())
if err != nil {
@@ -127,9 +125,6 @@ func TestConnAndPacketConn(t *testing.T) {
defer c2.Close()
c2.LocalAddr()
c2.RemoteAddr()
- c2.SetDeadline(time.Now().Add(500 * time.Millisecond))
- c2.SetReadDeadline(time.Now().Add(500 * time.Millisecond))
- c2.SetWriteDeadline(time.Now().Add(500 * time.Millisecond))
if _, err := c2.Write(wb); err != nil {
t.Fatal(err)
diff --git a/libgo/go/net/parse.go b/libgo/go/net/parse.go
index 6c230ab..ee2890f 100644
--- a/libgo/go/net/parse.go
+++ b/libgo/go/net/parse.go
@@ -208,6 +208,16 @@ func last(s string, b byte) int {
return i
}
+// hasUpperCase tells whether the given string contains at least one upper-case.
+func hasUpperCase(s string) bool {
+ for i := range s {
+ if 'A' <= s[i] && s[i] <= 'Z' {
+ return true
+ }
+ }
+ return false
+}
+
// lowerASCIIBytes makes x ASCII lowercase in-place.
func lowerASCIIBytes(x []byte) {
for i, b := range x {
@@ -331,26 +341,3 @@ func readFull(r io.Reader) (all []byte, err error) {
}
}
}
-
-// goDebugString returns the value of the named GODEBUG key.
-// GODEBUG is of the form "key=val,key2=val2"
-func goDebugString(key string) string {
- s := os.Getenv("GODEBUG")
- for i := 0; i < len(s)-len(key)-1; i++ {
- if i > 0 && s[i-1] != ',' {
- continue
- }
- afterKey := s[i+len(key):]
- if afterKey[0] != '=' || s[i:i+len(key)] != key {
- continue
- }
- val := afterKey[1:]
- for i, b := range val {
- if b == ',' {
- return val[:i]
- }
- }
- return val
- }
- return ""
-}
diff --git a/libgo/go/net/parse_test.go b/libgo/go/net/parse_test.go
index c5f8bfd..97716d7 100644
--- a/libgo/go/net/parse_test.go
+++ b/libgo/go/net/parse_test.go
@@ -51,33 +51,6 @@ func TestReadLine(t *testing.T) {
}
}
-func TestGoDebugString(t *testing.T) {
- defer os.Setenv("GODEBUG", os.Getenv("GODEBUG"))
- tests := []struct {
- godebug string
- key string
- want string
- }{
- {"", "foo", ""},
- {"foo=", "foo", ""},
- {"foo=bar", "foo", "bar"},
- {"foo=bar,", "foo", "bar"},
- {"foo,foo=bar,", "foo", "bar"},
- {"foo1=bar,foo=bar,", "foo", "bar"},
- {"foo=bar,foo=bar,", "foo", "bar"},
- {"foo=", "foo", ""},
- {"foo", "foo", ""},
- {",foo", "foo", ""},
- {"foo=bar,baz", "loooooooong", ""},
- }
- for _, tt := range tests {
- os.Setenv("GODEBUG", tt.godebug)
- if got := goDebugString(tt.key); got != tt.want {
- t.Errorf("for %q, goDebugString(%q) = %q; want %q", tt.godebug, tt.key, got, tt.want)
- }
- }
-}
-
func TestDtoi(t *testing.T) {
for _, tt := range []struct {
in string
diff --git a/libgo/go/net/platform_test.go b/libgo/go/net/platform_test.go
index 2da23de..c522ba2 100644
--- a/libgo/go/net/platform_test.go
+++ b/libgo/go/net/platform_test.go
@@ -34,8 +34,8 @@ func init() {
// testableNetwork reports whether network is testable on the current
// platform configuration.
func testableNetwork(network string) bool {
- ss := strings.Split(network, ":")
- switch ss[0] {
+ net, _, _ := strings.Cut(network, ":")
+ switch net {
case "ip+nopriv":
case "ip", "ip4", "ip6":
switch runtime.GOOS {
@@ -68,7 +68,7 @@ func testableNetwork(network string) bool {
}
}
}
- switch ss[0] {
+ switch net {
case "tcp4", "udp4", "ip4":
if !supportsIPv4() {
return false
@@ -88,7 +88,7 @@ func iOS() bool {
// testableAddress reports whether address of network is testable on
// the current platform configuration.
func testableAddress(network, address string) bool {
- switch ss := strings.Split(network, ":"); ss[0] {
+ switch net, _, _ := strings.Cut(network, ":"); net {
case "unix", "unixgram", "unixpacket":
// Abstract unix domain sockets, a Linux-ism.
if address[0] == '@' && runtime.GOOS != "linux" {
@@ -107,7 +107,7 @@ func testableListenArgs(network, address, client string) bool {
var err error
var addr Addr
- switch ss := strings.Split(network, ":"); ss[0] {
+ switch net, _, _ := strings.Cut(network, ":"); net {
case "tcp", "tcp4", "tcp6":
addr, err = ResolveTCPAddr("tcp", address)
case "udp", "udp4", "udp6":
@@ -173,7 +173,7 @@ func testableListenArgs(network, address, client string) bool {
return true
}
-func condFatalf(t *testing.T, network string, format string, args ...interface{}) {
+func condFatalf(t *testing.T, network string, format string, args ...any) {
t.Helper()
// A few APIs like File and Read/WriteMsg{UDP,IP} are not
// fully implemented yet on Plan 9 and Windows.
diff --git a/libgo/go/net/port_unix.go b/libgo/go/net/port_unix.go
index 07b4cbb..3527f1f 100644
--- a/libgo/go/net/port_unix.go
+++ b/libgo/go/net/port_unix.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build aix || darwin || dragonfly || freebsd || hurd || (js && wasm) || linux || netbsd || openbsd || solaris
-// +build aix darwin dragonfly freebsd hurd js,wasm linux netbsd openbsd solaris
// Read system port mappings from /etc/services
diff --git a/libgo/go/net/protoconn_test.go b/libgo/go/net/protoconn_test.go
index fc9b386..e4198a3 100644
--- a/libgo/go/net/protoconn_test.go
+++ b/libgo/go/net/protoconn_test.go
@@ -6,7 +6,6 @@
// tag.
//go:build !js
-// +build !js
package net
@@ -74,10 +73,7 @@ func TestTCPConnSpecificMethods(t *testing.T) {
}
ch := make(chan error, 1)
handler := func(ls *localServer, ln Listener) { ls.transponder(ls.Listener, ch) }
- ls, err := (&streamListener{Listener: ln}).newLocalServer()
- if err != nil {
- t.Fatal(err)
- }
+ ls := (&streamListener{Listener: ln}).newLocalServer()
defer ls.teardown()
if err := ls.buildup(handler); err != nil {
t.Fatal(err)
@@ -208,7 +204,7 @@ func TestUnixListenerSpecificMethods(t *testing.T) {
t.Skip("unix test")
}
- addr := testUnixAddr()
+ addr := testUnixAddr(t)
la, err := ResolveUnixAddr("unix", addr)
if err != nil {
t.Fatal(err)
@@ -249,7 +245,7 @@ func TestUnixConnSpecificMethods(t *testing.T) {
t.Skip("unixgram test")
}
- addr1, addr2, addr3 := testUnixAddr(), testUnixAddr(), testUnixAddr()
+ addr1, addr2, addr3 := testUnixAddr(t), testUnixAddr(t), testUnixAddr(t)
a1, err := ResolveUnixAddr("unixgram", addr1)
if err != nil {
diff --git a/libgo/go/net/rawconn_stub_test.go b/libgo/go/net/rawconn_stub_test.go
index 975aa8d..ff3d829 100644
--- a/libgo/go/net/rawconn_stub_test.go
+++ b/libgo/go/net/rawconn_stub_test.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build (js && wasm) || plan9
-// +build js,wasm plan9
package net
diff --git a/libgo/go/net/rawconn_test.go b/libgo/go/net/rawconn_test.go
index 3ef7af3..d1ef79d 100644
--- a/libgo/go/net/rawconn_test.go
+++ b/libgo/go/net/rawconn_test.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build !js
-// +build !js
package net
@@ -65,10 +64,7 @@ func TestRawConnReadWrite(t *testing.T) {
return
}
}
- ls, err := newLocalServer("tcp")
- if err != nil {
- t.Fatal(err)
- }
+ ls := newLocalServer(t, "tcp")
defer ls.teardown()
if err := ls.buildup(handler); err != nil {
t.Fatal(err)
@@ -103,10 +99,7 @@ func TestRawConnReadWrite(t *testing.T) {
t.Skipf("not supported on %s", runtime.GOOS)
}
- ln, err := newLocalListener("tcp")
- if err != nil {
- t.Fatal(err)
- }
+ ln := newLocalListener(t, "tcp")
defer ln.Close()
c, err := Dial(ln.Addr().Network(), ln.Addr().String())
@@ -181,10 +174,7 @@ func TestRawConnControl(t *testing.T) {
}
t.Run("TCP", func(t *testing.T) {
- ln, err := newLocalListener("tcp")
- if err != nil {
- t.Fatal(err)
- }
+ ln := newLocalListener(t, "tcp")
defer ln.Close()
cc1, err := ln.(*TCPListener).SyscallConn()
diff --git a/libgo/go/net/rawconn_unix_test.go b/libgo/go/net/rawconn_unix_test.go
index 77df4f8..7069d01 100644
--- a/libgo/go/net/rawconn_unix_test.go
+++ b/libgo/go/net/rawconn_unix_test.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build aix || darwin || dragonfly || freebsd || hurd || linux || netbsd || openbsd || solaris
-// +build aix darwin dragonfly freebsd hurd linux netbsd openbsd solaris
package net
diff --git a/libgo/go/net/rpc/client.go b/libgo/go/net/rpc/client.go
index 60bb2cc..42d1351 100644
--- a/libgo/go/net/rpc/client.go
+++ b/libgo/go/net/rpc/client.go
@@ -27,11 +27,11 @@ var ErrShutdown = errors.New("connection is shut down")
// Call represents an active RPC.
type Call struct {
- ServiceMethod string // The name of the service and method to call.
- Args interface{} // The argument to the function (*struct).
- Reply interface{} // The reply from the function (*struct).
- Error error // After completion, the error status.
- Done chan *Call // Receives *Call when Go is complete.
+ ServiceMethod string // The name of the service and method to call.
+ Args any // The argument to the function (*struct).
+ Reply any // The reply from the function (*struct).
+ Error error // After completion, the error status.
+ Done chan *Call // Receives *Call when Go is complete.
}
// Client represents an RPC Client.
@@ -61,9 +61,9 @@ type Client struct {
// discarded.
// See NewClient's comment for information about concurrent access.
type ClientCodec interface {
- WriteRequest(*Request, interface{}) error
+ WriteRequest(*Request, any) error
ReadResponseHeader(*Response) error
- ReadResponseBody(interface{}) error
+ ReadResponseBody(any) error
Close() error
}
@@ -214,7 +214,7 @@ type gobClientCodec struct {
encBuf *bufio.Writer
}
-func (c *gobClientCodec) WriteRequest(r *Request, body interface{}) (err error) {
+func (c *gobClientCodec) WriteRequest(r *Request, body any) (err error) {
if err = c.enc.Encode(r); err != nil {
return
}
@@ -228,7 +228,7 @@ func (c *gobClientCodec) ReadResponseHeader(r *Response) error {
return c.dec.Decode(r)
}
-func (c *gobClientCodec) ReadResponseBody(body interface{}) error {
+func (c *gobClientCodec) ReadResponseBody(body any) error {
return c.dec.Decode(body)
}
@@ -295,7 +295,7 @@ func (client *Client) Close() error {
// the invocation. The done channel will signal when the call is complete by returning
// the same Call object. If done is nil, Go will allocate a new channel.
// If non-nil, done must be buffered or Go will deliberately crash.
-func (client *Client) Go(serviceMethod string, args interface{}, reply interface{}, done chan *Call) *Call {
+func (client *Client) Go(serviceMethod string, args any, reply any, done chan *Call) *Call {
call := new(Call)
call.ServiceMethod = serviceMethod
call.Args = args
@@ -317,7 +317,7 @@ func (client *Client) Go(serviceMethod string, args interface{}, reply interface
}
// Call invokes the named function, waits for it to complete, and returns its error status.
-func (client *Client) Call(serviceMethod string, args interface{}, reply interface{}) error {
+func (client *Client) Call(serviceMethod string, args any, reply any) error {
call := <-client.Go(serviceMethod, args, reply, make(chan *Call, 1)).Done
return call.Error
}
diff --git a/libgo/go/net/rpc/client_test.go b/libgo/go/net/rpc/client_test.go
index 03225e3..ffc12fa 100644
--- a/libgo/go/net/rpc/client_test.go
+++ b/libgo/go/net/rpc/client_test.go
@@ -17,8 +17,8 @@ type shutdownCodec struct {
closed bool
}
-func (c *shutdownCodec) WriteRequest(*Request, interface{}) error { return nil }
-func (c *shutdownCodec) ReadResponseBody(interface{}) error { return nil }
+func (c *shutdownCodec) WriteRequest(*Request, any) error { return nil }
+func (c *shutdownCodec) ReadResponseBody(any) error { return nil }
func (c *shutdownCodec) ReadResponseHeader(*Response) error {
c.responded <- 1
return errors.New("shutdownCodec ReadResponseHeader")
@@ -57,8 +57,8 @@ func TestGobError(t *testing.T) {
if err == nil {
t.Fatal("no error")
}
- if !strings.Contains(err.(error).Error(), "reading body EOF") {
- t.Fatal("expected `reading body EOF', got", err)
+ if !strings.Contains(err.(error).Error(), "reading body unexpected EOF") {
+ t.Fatal("expected `reading body unexpected EOF', got", err)
}
}()
Register(new(S))
diff --git a/libgo/go/net/rpc/debug.go b/libgo/go/net/rpc/debug.go
index a1d799f..9e499fd 100644
--- a/libgo/go/net/rpc/debug.go
+++ b/libgo/go/net/rpc/debug.go
@@ -72,7 +72,7 @@ type debugHTTP struct {
func (server debugHTTP) ServeHTTP(w http.ResponseWriter, req *http.Request) {
// Build a sorted version of the data.
var services serviceArray
- server.serviceMap.Range(func(snamei, svci interface{}) bool {
+ server.serviceMap.Range(func(snamei, svci any) bool {
svc := svci.(*service)
ds := debugService{svc, snamei.(string), make(methodArray, 0, len(svc.method))}
for mname, method := range svc.method {
diff --git a/libgo/go/net/rpc/jsonrpc/all_test.go b/libgo/go/net/rpc/jsonrpc/all_test.go
index 667f839..f4e1278 100644
--- a/libgo/go/net/rpc/jsonrpc/all_test.go
+++ b/libgo/go/net/rpc/jsonrpc/all_test.go
@@ -28,9 +28,9 @@ type Reply struct {
type Arith int
type ArithAddResp struct {
- Id interface{} `json:"id"`
- Result Reply `json:"result"`
- Error interface{} `json:"error"`
+ Id any `json:"id"`
+ Result Reply `json:"result"`
+ Error any `json:"error"`
}
func (t *Arith) Add(args *Args, reply *Reply) error {
diff --git a/libgo/go/net/rpc/jsonrpc/client.go b/libgo/go/net/rpc/jsonrpc/client.go
index e6359be..c473017 100644
--- a/libgo/go/net/rpc/jsonrpc/client.go
+++ b/libgo/go/net/rpc/jsonrpc/client.go
@@ -44,12 +44,12 @@ func NewClientCodec(conn io.ReadWriteCloser) rpc.ClientCodec {
}
type clientRequest struct {
- Method string `json:"method"`
- Params [1]interface{} `json:"params"`
- Id uint64 `json:"id"`
+ Method string `json:"method"`
+ Params [1]any `json:"params"`
+ Id uint64 `json:"id"`
}
-func (c *clientCodec) WriteRequest(r *rpc.Request, param interface{}) error {
+func (c *clientCodec) WriteRequest(r *rpc.Request, param any) error {
c.mutex.Lock()
c.pending[r.Seq] = r.ServiceMethod
c.mutex.Unlock()
@@ -62,7 +62,7 @@ func (c *clientCodec) WriteRequest(r *rpc.Request, param interface{}) error {
type clientResponse struct {
Id uint64 `json:"id"`
Result *json.RawMessage `json:"result"`
- Error interface{} `json:"error"`
+ Error any `json:"error"`
}
func (r *clientResponse) reset() {
@@ -97,7 +97,7 @@ func (c *clientCodec) ReadResponseHeader(r *rpc.Response) error {
return nil
}
-func (c *clientCodec) ReadResponseBody(x interface{}) error {
+func (c *clientCodec) ReadResponseBody(x any) error {
if x == nil {
return nil
}
diff --git a/libgo/go/net/rpc/jsonrpc/server.go b/libgo/go/net/rpc/jsonrpc/server.go
index 40e4e6f..3ee4ddf 100644
--- a/libgo/go/net/rpc/jsonrpc/server.go
+++ b/libgo/go/net/rpc/jsonrpc/server.go
@@ -57,8 +57,8 @@ func (r *serverRequest) reset() {
type serverResponse struct {
Id *json.RawMessage `json:"id"`
- Result interface{} `json:"result"`
- Error interface{} `json:"error"`
+ Result any `json:"result"`
+ Error any `json:"error"`
}
func (c *serverCodec) ReadRequestHeader(r *rpc.Request) error {
@@ -81,7 +81,7 @@ func (c *serverCodec) ReadRequestHeader(r *rpc.Request) error {
return nil
}
-func (c *serverCodec) ReadRequestBody(x interface{}) error {
+func (c *serverCodec) ReadRequestBody(x any) error {
if x == nil {
return nil
}
@@ -92,14 +92,14 @@ func (c *serverCodec) ReadRequestBody(x interface{}) error {
// RPC params is struct.
// Unmarshal into array containing struct for now.
// Should think about making RPC more general.
- var params [1]interface{}
+ var params [1]any
params[0] = x
return json.Unmarshal(*c.req.Params, &params)
}
var null = json.RawMessage([]byte("null"))
-func (c *serverCodec) WriteResponse(r *rpc.Response, x interface{}) error {
+func (c *serverCodec) WriteResponse(r *rpc.Response, x any) error {
c.mutex.Lock()
b, ok := c.pending[r.Seq]
if !ok {
diff --git a/libgo/go/net/rpc/server.go b/libgo/go/net/rpc/server.go
index 074c5b9..d5207a4 100644
--- a/libgo/go/net/rpc/server.go
+++ b/libgo/go/net/rpc/server.go
@@ -203,7 +203,7 @@ var DefaultServer = NewServer()
// Is this type exported or a builtin?
func isExportedOrBuiltinType(t reflect.Type) bool {
- for t.Kind() == reflect.Ptr {
+ for t.Kind() == reflect.Pointer {
t = t.Elem()
}
// PkgPath will be non-empty even for an exported type,
@@ -221,17 +221,21 @@ func isExportedOrBuiltinType(t reflect.Type) bool {
// no suitable methods. It also logs the error using package log.
// The client accesses each method using a string of the form "Type.Method",
// where Type is the receiver's concrete type.
-func (server *Server) Register(rcvr interface{}) error {
+func (server *Server) Register(rcvr any) error {
return server.register(rcvr, "", false)
}
// RegisterName is like Register but uses the provided name for the type
// instead of the receiver's concrete type.
-func (server *Server) RegisterName(name string, rcvr interface{}) error {
+func (server *Server) RegisterName(name string, rcvr any) error {
return server.register(rcvr, name, true)
}
-func (server *Server) register(rcvr interface{}, name string, useName bool) error {
+// logRegisterError specifies whether to log problems during method registration.
+// To debug registration, recompile the package with this set to true.
+const logRegisterError = false
+
+func (server *Server) register(rcvr any, name string, useName bool) error {
s := new(service)
s.typ = reflect.TypeOf(rcvr)
s.rcvr = reflect.ValueOf(rcvr)
@@ -252,13 +256,13 @@ func (server *Server) register(rcvr interface{}, name string, useName bool) erro
s.name = sname
// Install the methods
- s.method = suitableMethods(s.typ, true)
+ s.method = suitableMethods(s.typ, logRegisterError)
if len(s.method) == 0 {
str := ""
// To help the user, see if a pointer receiver would work.
- method := suitableMethods(reflect.PtrTo(s.typ), false)
+ method := suitableMethods(reflect.PointerTo(s.typ), false)
if len(method) != 0 {
str = "rpc.Register: type " + sname + " has no exported methods of suitable type (hint: pass a pointer to value of that type)"
} else {
@@ -274,9 +278,9 @@ func (server *Server) register(rcvr interface{}, name string, useName bool) erro
return nil
}
-// suitableMethods returns suitable Rpc methods of typ, it will report
-// error using log if reportErr is true.
-func suitableMethods(typ reflect.Type, reportErr bool) map[string]*methodType {
+// suitableMethods returns suitable Rpc methods of typ. It will log
+// errors if logErr is true.
+func suitableMethods(typ reflect.Type, logErr bool) map[string]*methodType {
methods := make(map[string]*methodType)
for m := 0; m < typ.NumMethod(); m++ {
method := typ.Method(m)
@@ -288,7 +292,7 @@ func suitableMethods(typ reflect.Type, reportErr bool) map[string]*methodType {
}
// Method needs three ins: receiver, *args, *reply.
if mtype.NumIn() != 3 {
- if reportErr {
+ if logErr {
log.Printf("rpc.Register: method %q has %d input parameters; needs exactly three\n", mname, mtype.NumIn())
}
continue
@@ -296,36 +300,36 @@ func suitableMethods(typ reflect.Type, reportErr bool) map[string]*methodType {
// First arg need not be a pointer.
argType := mtype.In(1)
if !isExportedOrBuiltinType(argType) {
- if reportErr {
+ if logErr {
log.Printf("rpc.Register: argument type of method %q is not exported: %q\n", mname, argType)
}
continue
}
// Second arg must be a pointer.
replyType := mtype.In(2)
- if replyType.Kind() != reflect.Ptr {
- if reportErr {
+ if replyType.Kind() != reflect.Pointer {
+ if logErr {
log.Printf("rpc.Register: reply type of method %q is not a pointer: %q\n", mname, replyType)
}
continue
}
// Reply type must be exported.
if !isExportedOrBuiltinType(replyType) {
- if reportErr {
+ if logErr {
log.Printf("rpc.Register: reply type of method %q is not exported: %q\n", mname, replyType)
}
continue
}
// Method needs one out.
if mtype.NumOut() != 1 {
- if reportErr {
+ if logErr {
log.Printf("rpc.Register: method %q has %d output parameters; needs exactly one\n", mname, mtype.NumOut())
}
continue
}
// The return type of the method must be error.
if returnType := mtype.Out(0); returnType != typeOfError {
- if reportErr {
+ if logErr {
log.Printf("rpc.Register: return type of method %q is %q, must be error\n", mname, returnType)
}
continue
@@ -340,7 +344,7 @@ func suitableMethods(typ reflect.Type, reportErr bool) map[string]*methodType {
// contains an error when it is used.
var invalidRequest = struct{}{}
-func (server *Server) sendResponse(sending *sync.Mutex, req *Request, reply interface{}, codec ServerCodec, errmsg string) {
+func (server *Server) sendResponse(sending *sync.Mutex, req *Request, reply any, codec ServerCodec, errmsg string) {
resp := server.getResponse()
// Encode the response header
resp.ServiceMethod = req.ServiceMethod
@@ -397,11 +401,11 @@ func (c *gobServerCodec) ReadRequestHeader(r *Request) error {
return c.dec.Decode(r)
}
-func (c *gobServerCodec) ReadRequestBody(body interface{}) error {
+func (c *gobServerCodec) ReadRequestBody(body any) error {
return c.dec.Decode(body)
}
-func (c *gobServerCodec) WriteResponse(r *Response, body interface{}) (err error) {
+func (c *gobServerCodec) WriteResponse(r *Response, body any) (err error) {
if err = c.enc.Encode(r); err != nil {
if c.encBuf.Flush() == nil {
// Gob couldn't encode the header. Should not happen, so if it does,
@@ -552,7 +556,7 @@ func (server *Server) readRequest(codec ServerCodec) (service *service, mtype *m
// Decode the argument value.
argIsValue := false // if true, need to indirect before calling.
- if mtype.ArgType.Kind() == reflect.Ptr {
+ if mtype.ArgType.Kind() == reflect.Pointer {
argv = reflect.New(mtype.ArgType.Elem())
} else {
argv = reflect.New(mtype.ArgType)
@@ -632,11 +636,11 @@ func (server *Server) Accept(lis net.Listener) {
}
// Register publishes the receiver's methods in the DefaultServer.
-func Register(rcvr interface{}) error { return DefaultServer.Register(rcvr) }
+func Register(rcvr any) error { return DefaultServer.Register(rcvr) }
// RegisterName is like Register but uses the provided name for the type
// instead of the receiver's concrete type.
-func RegisterName(name string, rcvr interface{}) error {
+func RegisterName(name string, rcvr any) error {
return DefaultServer.RegisterName(name, rcvr)
}
@@ -650,8 +654,8 @@ func RegisterName(name string, rcvr interface{}) error {
// See NewClient's comment for information about concurrent access.
type ServerCodec interface {
ReadRequestHeader(*Request) error
- ReadRequestBody(interface{}) error
- WriteResponse(*Response, interface{}) error
+ ReadRequestBody(any) error
+ WriteResponse(*Response, any) error
// Close can be called multiple times and must be idempotent.
Close() error
diff --git a/libgo/go/net/rpc/server_test.go b/libgo/go/net/rpc/server_test.go
index e5d7fe0..dc5f5de 100644
--- a/libgo/go/net/rpc/server_test.go
+++ b/libgo/go/net/rpc/server_test.go
@@ -427,7 +427,7 @@ func (codec *CodecEmulator) ReadRequestHeader(req *Request) error {
return nil
}
-func (codec *CodecEmulator) ReadRequestBody(argv interface{}) error {
+func (codec *CodecEmulator) ReadRequestBody(argv any) error {
if codec.args == nil {
return io.ErrUnexpectedEOF
}
@@ -435,7 +435,7 @@ func (codec *CodecEmulator) ReadRequestBody(argv interface{}) error {
return nil
}
-func (codec *CodecEmulator) WriteResponse(resp *Response, reply interface{}) error {
+func (codec *CodecEmulator) WriteResponse(resp *Response, reply any) error {
if resp.Error != "" {
codec.err = errors.New(resp.Error)
} else {
@@ -521,7 +521,7 @@ func TestRegistrationError(t *testing.T) {
type WriteFailCodec int
-func (WriteFailCodec) WriteRequest(*Request, interface{}) error {
+func (WriteFailCodec) WriteRequest(*Request, any) error {
// the panic caused by this error used to not unlock a lock.
return errors.New("fail")
}
@@ -530,7 +530,7 @@ func (WriteFailCodec) ReadResponseHeader(*Response) error {
select {}
}
-func (WriteFailCodec) ReadResponseBody(interface{}) error {
+func (WriteFailCodec) ReadResponseBody(any) error {
select {}
}
diff --git a/libgo/go/net/sendfile_stub.go b/libgo/go/net/sendfile_stub.go
index 5753bc0..7428da3 100644
--- a/libgo/go/net/sendfile_stub.go
+++ b/libgo/go/net/sendfile_stub.go
@@ -2,8 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-//go:build aix || darwin || (js && wasm) || netbsd || openbsd
-// +build aix darwin js,wasm netbsd openbsd
+//go:build aix || (js && wasm) || netbsd || openbsd || ios
package net
diff --git a/libgo/go/net/sendfile_test.go b/libgo/go/net/sendfile_test.go
index 54e51fa..6edfb67 100644
--- a/libgo/go/net/sendfile_test.go
+++ b/libgo/go/net/sendfile_test.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build !js
-// +build !js
package net
@@ -28,10 +27,7 @@ const (
)
func TestSendfile(t *testing.T) {
- ln, err := newLocalListener("tcp")
- if err != nil {
- t.Fatal(err)
- }
+ ln := newLocalListener(t, "tcp")
defer ln.Close()
errc := make(chan error, 1)
@@ -98,10 +94,7 @@ func TestSendfile(t *testing.T) {
}
func TestSendfileParts(t *testing.T) {
- ln, err := newLocalListener("tcp")
- if err != nil {
- t.Fatal(err)
- }
+ ln := newLocalListener(t, "tcp")
defer ln.Close()
errc := make(chan error, 1)
@@ -156,10 +149,7 @@ func TestSendfileParts(t *testing.T) {
}
func TestSendfileSeeked(t *testing.T) {
- ln, err := newLocalListener("tcp")
- if err != nil {
- t.Fatal(err)
- }
+ ln := newLocalListener(t, "tcp")
defer ln.Close()
const seekTo = 65 << 10
@@ -226,10 +216,7 @@ func TestSendfilePipe(t *testing.T) {
t.Parallel()
- ln, err := newLocalListener("tcp")
- if err != nil {
- t.Fatal(err)
- }
+ ln := newLocalListener(t, "tcp")
defer ln.Close()
r, w, err := os.Pipe()
@@ -318,10 +305,7 @@ func TestSendfilePipe(t *testing.T) {
// Issue 43822: tests that returns EOF when conn write timeout.
func TestSendfileOnWriteTimeoutExceeded(t *testing.T) {
- ln, err := newLocalListener("tcp")
- if err != nil {
- t.Fatal(err)
- }
+ ln := newLocalListener(t, "tcp")
defer ln.Close()
errc := make(chan error, 1)
diff --git a/libgo/go/net/sendfile_unix_alt.go b/libgo/go/net/sendfile_unix_alt.go
index 54667d6..f99af92 100644
--- a/libgo/go/net/sendfile_unix_alt.go
+++ b/libgo/go/net/sendfile_unix_alt.go
@@ -2,8 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-//go:build dragonfly || freebsd || solaris
-// +build dragonfly freebsd solaris
+//go:build (darwin && !ios) || dragonfly || freebsd || solaris
package net
diff --git a/libgo/go/net/server_test.go b/libgo/go/net/server_test.go
index 7cbf152..6796d79 100644
--- a/libgo/go/net/server_test.go
+++ b/libgo/go/net/server_test.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build !js
-// +build !js
package net
@@ -78,10 +77,7 @@ func TestTCPServer(t *testing.T) {
}
}()
for i := 0; i < N; i++ {
- ls, err := (&streamListener{Listener: ln}).newLocalServer()
- if err != nil {
- t.Fatal(err)
- }
+ ls := (&streamListener{Listener: ln}).newLocalServer()
lss = append(lss, ls)
tpchs = append(tpchs, make(chan error, 1))
}
@@ -126,19 +122,19 @@ func TestTCPServer(t *testing.T) {
}
}
-var unixAndUnixpacketServerTests = []struct {
- network, address string
-}{
- {"unix", testUnixAddr()},
- {"unix", "@nettest/go/unix"},
-
- {"unixpacket", testUnixAddr()},
- {"unixpacket", "@nettest/go/unixpacket"},
-}
-
// TestUnixAndUnixpacketServer tests concurrent accept-read-write
// servers
func TestUnixAndUnixpacketServer(t *testing.T) {
+ var unixAndUnixpacketServerTests = []struct {
+ network, address string
+ }{
+ {"unix", testUnixAddr(t)},
+ {"unix", "@nettest/go/unix"},
+
+ {"unixpacket", testUnixAddr(t)},
+ {"unixpacket", "@nettest/go/unixpacket"},
+ }
+
const N = 3
for i, tt := range unixAndUnixpacketServerTests {
@@ -163,10 +159,7 @@ func TestUnixAndUnixpacketServer(t *testing.T) {
}
}()
for i := 0; i < N; i++ {
- ls, err := (&streamListener{Listener: ln}).newLocalServer()
- if err != nil {
- t.Fatal(err)
- }
+ ls := (&streamListener{Listener: ln}).newLocalServer()
lss = append(lss, ls)
tpchs = append(tpchs, make(chan error, 1))
}
@@ -188,7 +181,11 @@ func TestUnixAndUnixpacketServer(t *testing.T) {
}
t.Fatal(err)
}
- defer os.Remove(c.LocalAddr().String())
+
+ if addr := c.LocalAddr(); addr != nil {
+ t.Logf("connected %s->%s", addr, lss[i].Listener.Addr())
+ }
+
defer c.Close()
trchs = append(trchs, make(chan error, 1))
go transceiver(c, []byte("UNIX AND UNIXPACKET SERVER TEST"), trchs[i])
@@ -267,10 +264,7 @@ func TestUDPServer(t *testing.T) {
t.Fatal(err)
}
- ls, err := (&packetListener{PacketConn: c1}).newLocalServer()
- if err != nil {
- t.Fatal(err)
- }
+ ls := (&packetListener{PacketConn: c1}).newLocalServer()
defer ls.teardown()
tpch := make(chan error, 1)
handler := func(ls *localPacketServer, c PacketConn) { packetTransponder(c, tpch) }
@@ -319,18 +313,18 @@ func TestUDPServer(t *testing.T) {
}
}
-var unixgramServerTests = []struct {
- saddr string // server endpoint
- caddr string // client endpoint
- dial bool // test with Dial
-}{
- {saddr: testUnixAddr(), caddr: testUnixAddr()},
- {saddr: testUnixAddr(), caddr: testUnixAddr(), dial: true},
-
- {saddr: "@nettest/go/unixgram/server", caddr: "@nettest/go/unixgram/client"},
-}
-
func TestUnixgramServer(t *testing.T) {
+ var unixgramServerTests = []struct {
+ saddr string // server endpoint
+ caddr string // client endpoint
+ dial bool // test with Dial
+ }{
+ {saddr: testUnixAddr(t), caddr: testUnixAddr(t)},
+ {saddr: testUnixAddr(t), caddr: testUnixAddr(t), dial: true},
+
+ {saddr: "@nettest/go/unixgram/server", caddr: "@nettest/go/unixgram/client"},
+ }
+
for i, tt := range unixgramServerTests {
if !testableListenArgs("unixgram", tt.saddr, "") {
t.Logf("skipping %s test", "unixgram "+tt.saddr+"<-"+tt.caddr)
@@ -345,10 +339,7 @@ func TestUnixgramServer(t *testing.T) {
t.Fatal(err)
}
- ls, err := (&packetListener{PacketConn: c1}).newLocalServer()
- if err != nil {
- t.Fatal(err)
- }
+ ls := (&packetListener{PacketConn: c1}).newLocalServer()
defer ls.teardown()
tpch := make(chan error, 1)
handler := func(ls *localPacketServer, c PacketConn) { packetTransponder(c, tpch) }
diff --git a/libgo/go/net/smtp/smtp.go b/libgo/go/net/smtp/smtp.go
index 1a6864a..c1f00a0 100644
--- a/libgo/go/net/smtp/smtp.go
+++ b/libgo/go/net/smtp/smtp.go
@@ -105,7 +105,7 @@ func (c *Client) Hello(localName string) error {
}
// cmd is a convenience function that sends a command and returns the response
-func (c *Client) cmd(expectCode int, format string, args ...interface{}) (int, string, error) {
+func (c *Client) cmd(expectCode int, format string, args ...any) (int, string, error) {
id, err := c.Text.Cmd(format, args...)
if err != nil {
return 0, "", err
@@ -136,12 +136,8 @@ func (c *Client) ehlo() error {
if len(extList) > 1 {
extList = extList[1:]
for _, line := range extList {
- args := strings.SplitN(line, " ", 2)
- if len(args) > 1 {
- ext[args[0]] = args[1]
- } else {
- ext[args[0]] = ""
- }
+ k, v, _ := strings.Cut(line, " ")
+ ext[k] = v
}
}
if mechs, ok := ext["AUTH"]; ok {
diff --git a/libgo/go/net/smtp/smtp_test.go b/libgo/go/net/smtp/smtp_test.go
index 5521937..0f758f4 100644
--- a/libgo/go/net/smtp/smtp_test.go
+++ b/libgo/go/net/smtp/smtp_test.go
@@ -948,7 +948,7 @@ QUIT
`
func TestTLSClient(t *testing.T) {
- if (runtime.GOOS == "freebsd" && runtime.GOARCH == "amd64") || runtime.GOOS == "js" {
+ if runtime.GOOS == "freebsd" || runtime.GOOS == "js" {
testenv.SkipFlaky(t, 19229)
}
ln := newLocalListener(t)
diff --git a/libgo/go/net/sock_bsd.go b/libgo/go/net/sock_bsd.go
index 4c883ad..27daf72 100644
--- a/libgo/go/net/sock_bsd.go
+++ b/libgo/go/net/sock_bsd.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build darwin || dragonfly || freebsd || netbsd || openbsd
-// +build darwin dragonfly freebsd netbsd openbsd
package net
diff --git a/libgo/go/net/sock_cloexec.go b/libgo/go/net/sock_cloexec.go
index cb57bb4..6321dbc 100644
--- a/libgo/go/net/sock_cloexec.go
+++ b/libgo/go/net/sock_cloexec.go
@@ -6,7 +6,6 @@
// setting SetNonblock and CloseOnExec.
//go:build dragonfly || freebsd || hurd || illumos || linux || netbsd || openbsd
-// +build dragonfly freebsd hurd illumos linux netbsd openbsd
package net
diff --git a/libgo/go/net/sock_posix.go b/libgo/go/net/sock_posix.go
index 8c09b0b..fbdec81 100644
--- a/libgo/go/net/sock_posix.go
+++ b/libgo/go/net/sock_posix.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build aix || darwin || dragonfly || freebsd || hurd || linux || netbsd || openbsd || solaris || windows
-// +build aix darwin dragonfly freebsd hurd linux netbsd openbsd solaris windows
package net
diff --git a/libgo/go/net/sock_stub.go b/libgo/go/net/sock_stub.go
index 1e5032e..e5883d02 100644
--- a/libgo/go/net/sock_stub.go
+++ b/libgo/go/net/sock_stub.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build aix || hurd || (js && wasm) || solaris
-// +build aix hurd js,wasm solaris
package net
diff --git a/libgo/go/net/sockaddr_posix.go b/libgo/go/net/sockaddr_posix.go
index 618d85f..050eac7 100644
--- a/libgo/go/net/sockaddr_posix.go
+++ b/libgo/go/net/sockaddr_posix.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build aix || darwin || dragonfly || freebsd || hurd || (js && wasm) || linux || netbsd || openbsd || solaris || windows
-// +build aix darwin dragonfly freebsd hurd js,wasm linux netbsd openbsd solaris windows
package net
diff --git a/libgo/go/net/sockopt_bsd.go b/libgo/go/net/sockopt_bsd.go
index e52fa88..8934e4c 100644
--- a/libgo/go/net/sockopt_bsd.go
+++ b/libgo/go/net/sockopt_bsd.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build darwin || dragonfly || freebsd || netbsd || openbsd
-// +build darwin dragonfly freebsd netbsd openbsd
package net
diff --git a/libgo/go/net/sockopt_posix.go b/libgo/go/net/sockopt_posix.go
index 3478872..1d92668 100644
--- a/libgo/go/net/sockopt_posix.go
+++ b/libgo/go/net/sockopt_posix.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build aix || darwin || dragonfly || freebsd || hurd || linux || netbsd || openbsd || solaris || windows
-// +build aix darwin dragonfly freebsd hurd linux netbsd openbsd solaris windows
package net
diff --git a/libgo/go/net/sockopt_stub.go b/libgo/go/net/sockopt_stub.go
index 99b5277..98e2371 100644
--- a/libgo/go/net/sockopt_stub.go
+++ b/libgo/go/net/sockopt_stub.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build js && wasm
-// +build js,wasm
package net
diff --git a/libgo/go/net/sockoptip_bsdvar.go b/libgo/go/net/sockoptip_bsdvar.go
index 8b0b5d2..696fa30 100644
--- a/libgo/go/net/sockoptip_bsdvar.go
+++ b/libgo/go/net/sockoptip_bsdvar.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build aix || darwin || dragonfly || freebsd || hurd || netbsd || openbsd || solaris
-// +build aix darwin dragonfly freebsd hurd netbsd openbsd solaris
package net
diff --git a/libgo/go/net/sockoptip_posix.go b/libgo/go/net/sockoptip_posix.go
index a063e79..3d47afd 100644
--- a/libgo/go/net/sockoptip_posix.go
+++ b/libgo/go/net/sockoptip_posix.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build aix || darwin || dragonfly || freebsd || hurd || linux || netbsd || openbsd || solaris || windows
-// +build aix darwin dragonfly freebsd hurd linux netbsd openbsd solaris windows
package net
diff --git a/libgo/go/net/sockoptip_stub.go b/libgo/go/net/sockoptip_stub.go
index 4175922..2c993eb 100644
--- a/libgo/go/net/sockoptip_stub.go
+++ b/libgo/go/net/sockoptip_stub.go
@@ -3,38 +3,31 @@
// license that can be found in the LICENSE file.
//go:build js && wasm
-// +build js,wasm
package net
import "syscall"
func setIPv4MulticastInterface(fd *netFD, ifi *Interface) error {
- // See golang.org/issue/7399.
return syscall.ENOPROTOOPT
}
func setIPv4MulticastLoopback(fd *netFD, v bool) error {
- // See golang.org/issue/7399.
return syscall.ENOPROTOOPT
}
func joinIPv4Group(fd *netFD, ifi *Interface, ip IP) error {
- // See golang.org/issue/7399.
return syscall.ENOPROTOOPT
}
func setIPv6MulticastInterface(fd *netFD, ifi *Interface) error {
- // See golang.org/issue/7399.
return syscall.ENOPROTOOPT
}
func setIPv6MulticastLoopback(fd *netFD, v bool) error {
- // See golang.org/issue/7399.
return syscall.ENOPROTOOPT
}
func joinIPv6Group(fd *netFD, ifi *Interface, ip IP) error {
- // See golang.org/issue/7399.
return syscall.ENOPROTOOPT
}
diff --git a/libgo/go/net/splice_stub.go b/libgo/go/net/splice_stub.go
index ce2e904..3cdadb1 100644
--- a/libgo/go/net/splice_stub.go
+++ b/libgo/go/net/splice_stub.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build !linux
-// +build !linux
package net
diff --git a/libgo/go/net/splice_test.go b/libgo/go/net/splice_test.go
index d5f6367..5ad9fcd 100644
--- a/libgo/go/net/splice_test.go
+++ b/libgo/go/net/splice_test.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build linux
-// +build linux
package net
@@ -47,20 +46,14 @@ type spliceTestCase struct {
}
func (tc spliceTestCase) test(t *testing.T) {
- clientUp, serverUp, err := spliceTestSocketPair(tc.upNet)
- if err != nil {
- t.Fatal(err)
- }
+ clientUp, serverUp := spliceTestSocketPair(t, tc.upNet)
defer serverUp.Close()
cleanup, err := startSpliceClient(clientUp, "w", tc.chunkSize, tc.totalSize)
if err != nil {
t.Fatal(err)
}
defer cleanup()
- clientDown, serverDown, err := spliceTestSocketPair(tc.downNet)
- if err != nil {
- t.Fatal(err)
- }
+ clientDown, serverDown := spliceTestSocketPair(t, tc.downNet)
defer serverDown.Close()
cleanup, err = startSpliceClient(clientDown, "r", tc.chunkSize, tc.totalSize)
if err != nil {
@@ -104,15 +97,9 @@ func (tc spliceTestCase) test(t *testing.T) {
}
func testSpliceReaderAtEOF(t *testing.T, upNet, downNet string) {
- clientUp, serverUp, err := spliceTestSocketPair(upNet)
- if err != nil {
- t.Fatal(err)
- }
+ clientUp, serverUp := spliceTestSocketPair(t, upNet)
defer clientUp.Close()
- clientDown, serverDown, err := spliceTestSocketPair(downNet)
- if err != nil {
- t.Fatal(err)
- }
+ clientDown, serverDown := spliceTestSocketPair(t, downNet)
defer clientDown.Close()
serverUp.Close()
@@ -141,7 +128,7 @@ func testSpliceReaderAtEOF(t *testing.T, upNet, downNet string) {
}()
buf := make([]byte, 3)
- _, err = io.ReadFull(clientDown, buf)
+ _, err := io.ReadFull(clientDown, buf)
if err != nil {
t.Errorf("clientDown: %v", err)
}
@@ -151,15 +138,9 @@ func testSpliceReaderAtEOF(t *testing.T, upNet, downNet string) {
}
func testSpliceIssue25985(t *testing.T, upNet, downNet string) {
- front, err := newLocalListener(upNet)
- if err != nil {
- t.Fatal(err)
- }
+ front := newLocalListener(t, upNet)
defer front.Close()
- back, err := newLocalListener(downNet)
- if err != nil {
- t.Fatal(err)
- }
+ back := newLocalListener(t, downNet)
defer back.Close()
var wg sync.WaitGroup
@@ -211,16 +192,10 @@ func testSpliceIssue25985(t *testing.T, upNet, downNet string) {
}
func testSpliceNoUnixpacket(t *testing.T) {
- clientUp, serverUp, err := spliceTestSocketPair("unixpacket")
- if err != nil {
- t.Fatal(err)
- }
+ clientUp, serverUp := spliceTestSocketPair(t, "unixpacket")
defer clientUp.Close()
defer serverUp.Close()
- clientDown, serverDown, err := spliceTestSocketPair("tcp")
- if err != nil {
- t.Fatal(err)
- }
+ clientDown, serverDown := spliceTestSocketPair(t, "tcp")
defer clientDown.Close()
defer serverDown.Close()
// If splice called poll.Splice here, we'd get err == syscall.EINVAL
@@ -238,7 +213,7 @@ func testSpliceNoUnixpacket(t *testing.T) {
}
func testSpliceNoUnixgram(t *testing.T) {
- addr, err := ResolveUnixAddr("unixgram", testUnixAddr())
+ addr, err := ResolveUnixAddr("unixgram", testUnixAddr(t))
if err != nil {
t.Fatal(err)
}
@@ -248,10 +223,7 @@ func testSpliceNoUnixgram(t *testing.T) {
t.Fatal(err)
}
defer up.Close()
- clientDown, serverDown, err := spliceTestSocketPair("tcp")
- if err != nil {
- t.Fatal(err)
- }
+ clientDown, serverDown := spliceTestSocketPair(t, "tcp")
defer clientDown.Close()
defer serverDown.Close()
// Analogous to testSpliceNoUnixpacket.
@@ -285,10 +257,7 @@ func (tc spliceTestCase) bench(b *testing.B) {
// To benchmark the genericReadFrom code path, set this to false.
useSplice := true
- clientUp, serverUp, err := spliceTestSocketPair(tc.upNet)
- if err != nil {
- b.Fatal(err)
- }
+ clientUp, serverUp := spliceTestSocketPair(b, tc.upNet)
defer serverUp.Close()
cleanup, err := startSpliceClient(clientUp, "w", tc.chunkSize, tc.chunkSize*b.N)
@@ -297,10 +266,7 @@ func (tc spliceTestCase) bench(b *testing.B) {
}
defer cleanup()
- clientDown, serverDown, err := spliceTestSocketPair(tc.downNet)
- if err != nil {
- b.Fatal(err)
- }
+ clientDown, serverDown := spliceTestSocketPair(b, tc.downNet)
defer serverDown.Close()
cleanup, err = startSpliceClient(clientDown, "r", tc.chunkSize, tc.chunkSize*b.N)
@@ -328,11 +294,9 @@ func (tc spliceTestCase) bench(b *testing.B) {
}
}
-func spliceTestSocketPair(net string) (client, server Conn, err error) {
- ln, err := newLocalListener(net)
- if err != nil {
- return nil, nil, err
- }
+func spliceTestSocketPair(t testing.TB, net string) (client, server Conn) {
+ t.Helper()
+ ln := newLocalListener(t, net)
defer ln.Close()
var cerr, serr error
acceptDone := make(chan struct{})
@@ -346,15 +310,15 @@ func spliceTestSocketPair(net string) (client, server Conn, err error) {
if server != nil {
server.Close()
}
- return nil, nil, cerr
+ t.Fatal(cerr)
}
if serr != nil {
if client != nil {
client.Close()
}
- return nil, nil, serr
+ t.Fatal(serr)
}
- return client, server, nil
+ return client, server
}
func startSpliceClient(conn Conn, op string, chunkSize, totalSize int) (func(), error) {
diff --git a/libgo/go/net/sys_cloexec.go b/libgo/go/net/sys_cloexec.go
index a32483e..26eac55 100644
--- a/libgo/go/net/sys_cloexec.go
+++ b/libgo/go/net/sys_cloexec.go
@@ -6,7 +6,6 @@
// for setting SetNonblock and CloseOnExec.
//go:build aix || darwin || (solaris && !illumos)
-// +build aix darwin solaris,!illumos
package net
diff --git a/libgo/go/net/tcpsock.go b/libgo/go/net/tcpsock.go
index 19a90143..6bad0e8 100644
--- a/libgo/go/net/tcpsock.go
+++ b/libgo/go/net/tcpsock.go
@@ -8,6 +8,7 @@ import (
"context"
"internal/itoa"
"io"
+ "net/netip"
"os"
"syscall"
"time"
@@ -23,6 +24,20 @@ type TCPAddr struct {
Zone string // IPv6 scoped addressing zone
}
+// AddrPort returns the TCPAddr a as a netip.AddrPort.
+//
+// If a.Port does not fit in a uint16, it's silently truncated.
+//
+// If a is nil, a zero value is returned.
+func (a *TCPAddr) AddrPort() netip.AddrPort {
+ if a == nil {
+ return netip.AddrPort{}
+ }
+ na, _ := netip.AddrFromSlice(a.IP)
+ na = na.WithZone(a.Zone)
+ return netip.AddrPortFrom(na, uint16(a.Port))
+}
+
// Network returns the address's network name, "tcp".
func (a *TCPAddr) Network() string { return "tcp" }
@@ -81,6 +96,17 @@ func ResolveTCPAddr(network, address string) (*TCPAddr, error) {
return addrs.forResolve(network, address).(*TCPAddr), nil
}
+// TCPAddrFromAddrPort returns addr as a TCPAddr. If addr.IsValid() is false,
+// then the returned TCPAddr will contain a nil IP field, indicating an
+// address family-agnostic unspecified address.
+func TCPAddrFromAddrPort(addr netip.AddrPort) *TCPAddr {
+ return &TCPAddr{
+ IP: addr.Addr().AsSlice(),
+ Zone: addr.Addr().Zone(),
+ Port: int(addr.Port()),
+ }
+}
+
// TCPConn is an implementation of the Conn interface for TCP network
// connections.
type TCPConn struct {
diff --git a/libgo/go/net/tcpsock_posix.go b/libgo/go/net/tcpsock_posix.go
index 9fd7822..8237909 100644
--- a/libgo/go/net/tcpsock_posix.go
+++ b/libgo/go/net/tcpsock_posix.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build aix || darwin || dragonfly || freebsd || hurd || (js && wasm) || linux || netbsd || openbsd || solaris || windows
-// +build aix darwin dragonfly freebsd hurd js,wasm linux netbsd openbsd solaris windows
package net
diff --git a/libgo/go/net/tcpsock_test.go b/libgo/go/net/tcpsock_test.go
index 884c5cb..5cff961 100644
--- a/libgo/go/net/tcpsock_test.go
+++ b/libgo/go/net/tcpsock_test.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build !js
-// +build !js
package net
@@ -388,10 +387,7 @@ func TestIPv6LinkLocalUnicastTCP(t *testing.T) {
t.Log(err)
continue
}
- ls, err := (&streamListener{Listener: ln}).newLocalServer()
- if err != nil {
- t.Fatal(err)
- }
+ ls := (&streamListener{Listener: ln}).newLocalServer()
defer ls.teardown()
ch := make(chan error, 1)
handler := func(ls *localServer, ln Listener) { ls.transponder(ln, ch) }
@@ -632,10 +628,7 @@ func TestTCPSelfConnect(t *testing.T) {
t.Skip("known-broken test on windows")
}
- ln, err := newLocalListener("tcp")
- if err != nil {
- t.Fatal(err)
- }
+ ln := newLocalListener(t, "tcp")
var d Dialer
c, err := d.Dial(ln.Addr().Network(), ln.Addr().String())
if err != nil {
@@ -682,10 +675,7 @@ func TestTCPBig(t *testing.T) {
for _, writev := range []bool{false, true} {
t.Run(fmt.Sprintf("writev=%v", writev), func(t *testing.T) {
- ln, err := newLocalListener("tcp")
- if err != nil {
- t.Fatal(err)
- }
+ ln := newLocalListener(t, "tcp")
defer ln.Close()
x := int(1 << 30)
@@ -729,10 +719,7 @@ func TestTCPBig(t *testing.T) {
}
func TestCopyPipeIntoTCP(t *testing.T) {
- ln, err := newLocalListener("tcp")
- if err != nil {
- t.Fatal(err)
- }
+ ln := newLocalListener(t, "tcp")
defer ln.Close()
errc := make(chan error, 1)
@@ -800,10 +787,7 @@ func TestCopyPipeIntoTCP(t *testing.T) {
}
func BenchmarkSetReadDeadline(b *testing.B) {
- ln, err := newLocalListener("tcp")
- if err != nil {
- b.Fatal(err)
- }
+ ln := newLocalListener(b, "tcp")
defer ln.Close()
var serv Conn
done := make(chan error)
diff --git a/libgo/go/net/tcpsock_unix_test.go b/libgo/go/net/tcpsock_unix_test.go
index 41bd229..b14670b 100644
--- a/libgo/go/net/tcpsock_unix_test.go
+++ b/libgo/go/net/tcpsock_unix_test.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build !js && !plan9 && !windows
-// +build !js,!plan9,!windows
package net
@@ -23,10 +22,7 @@ func TestTCPSpuriousConnSetupCompletion(t *testing.T) {
t.Skip("skipping in short mode")
}
- ln, err := newLocalListener("tcp")
- if err != nil {
- t.Fatal(err)
- }
+ ln := newLocalListener(t, "tcp")
var wg sync.WaitGroup
wg.Add(1)
go func(ln Listener) {
diff --git a/libgo/go/net/tcpsockopt_posix.go b/libgo/go/net/tcpsockopt_posix.go
index 4c99ab8..ad54d1b 100644
--- a/libgo/go/net/tcpsockopt_posix.go
+++ b/libgo/go/net/tcpsockopt_posix.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build aix || darwin || dragonfly || freebsd || hurd || linux || netbsd || openbsd || solaris || windows
-// +build aix darwin dragonfly freebsd hurd linux netbsd openbsd solaris windows
package net
diff --git a/libgo/go/net/tcpsockopt_stub.go b/libgo/go/net/tcpsockopt_stub.go
index 028d5fd..0fe9182 100644
--- a/libgo/go/net/tcpsockopt_stub.go
+++ b/libgo/go/net/tcpsockopt_stub.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build js && wasm
-// +build js,wasm
package net
diff --git a/libgo/go/net/tcpsockopt_unix.go b/libgo/go/net/tcpsockopt_unix.go
index cc0662a..edcab44 100644
--- a/libgo/go/net/tcpsockopt_unix.go
+++ b/libgo/go/net/tcpsockopt_unix.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build aix || freebsd || hurd || linux || netbsd
-// +build aix freebsd hurd linux netbsd
package net
diff --git a/libgo/go/net/textproto/reader.go b/libgo/go/net/textproto/reader.go
index 5c3084f..157c59b 100644
--- a/libgo/go/net/textproto/reader.go
+++ b/libgo/go/net/textproto/reader.go
@@ -460,6 +460,8 @@ func (r *Reader) ReadDotLines() ([]string, error) {
return v, err
}
+var colon = []byte(":")
+
// ReadMIMEHeader reads a MIME-style header from r.
// The header is a sequence of possibly continued Key: Value lines
// ending in a blank line.
@@ -508,11 +510,11 @@ func (r *Reader) ReadMIMEHeader() (MIMEHeader, error) {
}
// Key ends at first colon.
- i := bytes.IndexByte(kv, ':')
- if i < 0 {
+ k, v, ok := bytes.Cut(kv, colon)
+ if !ok {
return m, ProtocolError("malformed MIME header line: " + string(kv))
}
- key := canonicalMIMEHeaderKey(kv[:i])
+ key := canonicalMIMEHeaderKey(k)
// As per RFC 7230 field-name is a token, tokens consist of one or more chars.
// We could return a ProtocolError here, but better to be liberal in what we
@@ -522,11 +524,7 @@ func (r *Reader) ReadMIMEHeader() (MIMEHeader, error) {
}
// Skip initial spaces in value.
- i++ // skip colon
- for i < len(kv) && (kv[i] == ' ' || kv[i] == '\t') {
- i++
- }
- value := string(kv[i:])
+ value := strings.TrimLeft(string(v), " \t")
vv := m[key]
if vv == nil && len(strs) > 0 {
@@ -561,6 +559,8 @@ func mustHaveFieldNameColon(line []byte) error {
return nil
}
+var nl = []byte("\n")
+
// upcomingHeaderNewlines returns an approximation of the number of newlines
// that will be in this header. If it gets confused, it returns 0.
func (r *Reader) upcomingHeaderNewlines() (n int) {
@@ -571,17 +571,7 @@ func (r *Reader) upcomingHeaderNewlines() (n int) {
return
}
peek, _ := r.R.Peek(s)
- for len(peek) > 0 {
- i := bytes.IndexByte(peek, '\n')
- if i < 3 {
- // Not present (-1) or found within the next few bytes,
- // implying we're at the end ("\r\n\r\n" or "\n\n")
- return
- }
- n++
- peek = peek[i+1:]
- }
- return
+ return bytes.Count(peek, nl)
}
// CanonicalMIMEHeaderKey returns the canonical format of the
diff --git a/libgo/go/net/textproto/textproto.go b/libgo/go/net/textproto/textproto.go
index 8fd781e..cc1a847 100644
--- a/libgo/go/net/textproto/textproto.go
+++ b/libgo/go/net/textproto/textproto.go
@@ -111,7 +111,7 @@ func Dial(network, addr string) (*Conn, error) {
// }
// return c.ReadCodeLine(250)
//
-func (c *Conn) Cmd(format string, args ...interface{}) (id uint, err error) {
+func (c *Conn) Cmd(format string, args ...any) (id uint, err error) {
id = c.Next()
c.StartRequest(id)
err = c.PrintfLine(format, args...)
diff --git a/libgo/go/net/textproto/writer.go b/libgo/go/net/textproto/writer.go
index 33c146c..2ece3f5 100644
--- a/libgo/go/net/textproto/writer.go
+++ b/libgo/go/net/textproto/writer.go
@@ -26,7 +26,7 @@ var crnl = []byte{'\r', '\n'}
var dotcrnl = []byte{'.', '\r', '\n'}
// PrintfLine writes the formatted output followed by \r\n.
-func (w *Writer) PrintfLine(format string, args ...interface{}) error {
+func (w *Writer) PrintfLine(format string, args ...any) error {
w.closeDot()
fmt.Fprintf(w.W, format, args...)
w.W.Write(crnl)
diff --git a/libgo/go/net/timeout_test.go b/libgo/go/net/timeout_test.go
index e1cf146..d1cfbf8 100644
--- a/libgo/go/net/timeout_test.go
+++ b/libgo/go/net/timeout_test.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build !js
-// +build !js
package net
@@ -93,53 +92,35 @@ func TestDialTimeout(t *testing.T) {
}
}
-var dialTimeoutMaxDurationTests = []struct {
- timeout time.Duration
- delta time.Duration // for deadline
-}{
- // Large timeouts that will overflow an int64 unix nanos.
- {1<<63 - 1, 0},
- {0, 1<<63 - 1},
-}
-
func TestDialTimeoutMaxDuration(t *testing.T) {
- if runtime.GOOS == "openbsd" {
- testenv.SkipFlaky(t, 15157)
- }
-
- ln, err := newLocalListener("tcp")
- if err != nil {
- t.Fatal(err)
- }
- defer ln.Close()
+ ln := newLocalListener(t, "tcp")
+ defer func() {
+ if err := ln.Close(); err != nil {
+ t.Error(err)
+ }
+ }()
- for i, tt := range dialTimeoutMaxDurationTests {
- ch := make(chan error)
- max := time.NewTimer(250 * time.Millisecond)
- defer max.Stop()
- go func() {
+ for _, tt := range []struct {
+ timeout time.Duration
+ delta time.Duration // for deadline
+ }{
+ // Large timeouts that will overflow an int64 unix nanos.
+ {1<<63 - 1, 0},
+ {0, 1<<63 - 1},
+ } {
+ t.Run(fmt.Sprintf("timeout=%s/delta=%s", tt.timeout, tt.delta), func(t *testing.T) {
d := Dialer{Timeout: tt.timeout}
if tt.delta != 0 {
d.Deadline = time.Now().Add(tt.delta)
}
c, err := d.Dial(ln.Addr().Network(), ln.Addr().String())
- if err == nil {
- c.Close()
- }
- ch <- err
- }()
-
- select {
- case <-max.C:
- t.Fatalf("#%d: Dial didn't return in an expected time", i)
- case err := <-ch:
- if perr := parseDialError(err); perr != nil {
- t.Error(perr)
- }
if err != nil {
- t.Errorf("#%d: %v", i, err)
+ t.Fatal(err)
}
- }
+ if err := c.Close(); err != nil {
+ t.Error(err)
+ }
+ })
}
}
@@ -163,10 +144,7 @@ func TestAcceptTimeout(t *testing.T) {
t.Skipf("not supported on %s", runtime.GOOS)
}
- ln, err := newLocalListener("tcp")
- if err != nil {
- t.Fatal(err)
- }
+ ln := newLocalListener(t, "tcp")
defer ln.Close()
var wg sync.WaitGroup
@@ -219,10 +197,7 @@ func TestAcceptTimeoutMustReturn(t *testing.T) {
t.Skipf("not supported on %s", runtime.GOOS)
}
- ln, err := newLocalListener("tcp")
- if err != nil {
- t.Fatal(err)
- }
+ ln := newLocalListener(t, "tcp")
defer ln.Close()
max := time.NewTimer(time.Second)
@@ -265,10 +240,7 @@ func TestAcceptTimeoutMustNotReturn(t *testing.T) {
t.Skipf("not supported on %s", runtime.GOOS)
}
- ln, err := newLocalListener("tcp")
- if err != nil {
- t.Fatal(err)
- }
+ ln := newLocalListener(t, "tcp")
defer ln.Close()
max := time.NewTimer(100 * time.Millisecond)
@@ -318,10 +290,7 @@ func TestReadTimeout(t *testing.T) {
c.Write([]byte("READ TIMEOUT TEST"))
defer c.Close()
}
- ls, err := newLocalServer("tcp")
- if err != nil {
- t.Fatal(err)
- }
+ ls := newLocalServer(t, "tcp")
defer ls.teardown()
if err := ls.buildup(handler); err != nil {
t.Fatal(err)
@@ -370,10 +339,7 @@ func TestReadTimeoutMustNotReturn(t *testing.T) {
t.Skipf("not supported on %s", runtime.GOOS)
}
- ln, err := newLocalListener("tcp")
- if err != nil {
- t.Fatal(err)
- }
+ ln := newLocalListener(t, "tcp")
defer ln.Close()
c, err := Dial(ln.Addr().Network(), ln.Addr().String())
@@ -437,10 +403,7 @@ func TestReadFromTimeout(t *testing.T) {
c.WriteTo([]byte("READFROM TIMEOUT TEST"), dst)
}
}
- ls, err := newLocalPacketServer("udp")
- if err != nil {
- t.Fatal(err)
- }
+ ls := newLocalPacketServer(t, "udp")
defer ls.teardown()
if err := ls.buildup(handler); err != nil {
t.Fatal(err)
@@ -500,10 +463,7 @@ var writeTimeoutTests = []struct {
func TestWriteTimeout(t *testing.T) {
t.Parallel()
- ln, err := newLocalListener("tcp")
- if err != nil {
- t.Fatal(err)
- }
+ ln := newLocalListener(t, "tcp")
defer ln.Close()
for i, tt := range writeTimeoutTests {
@@ -548,10 +508,7 @@ func TestWriteTimeoutMustNotReturn(t *testing.T) {
t.Skipf("not supported on %s", runtime.GOOS)
}
- ln, err := newLocalListener("tcp")
- if err != nil {
- t.Fatal(err)
- }
+ ln := newLocalListener(t, "tcp")
defer ln.Close()
c, err := Dial(ln.Addr().Network(), ln.Addr().String())
@@ -600,24 +557,10 @@ func TestWriteTimeoutMustNotReturn(t *testing.T) {
}
}
-var writeToTimeoutTests = []struct {
- timeout time.Duration
- xerrs [2]error // expected errors in transition
-}{
- // Tests that write deadlines work, even if there's buffer
- // space available to write.
- {-5 * time.Second, [2]error{os.ErrDeadlineExceeded, os.ErrDeadlineExceeded}},
-
- {10 * time.Millisecond, [2]error{nil, os.ErrDeadlineExceeded}},
-}
-
func TestWriteToTimeout(t *testing.T) {
t.Parallel()
- c1, err := newLocalPacketListener("udp")
- if err != nil {
- t.Fatal(err)
- }
+ c1 := newLocalPacketListener(t, "udp")
defer c1.Close()
host, _, err := SplitHostPort(c1.LocalAddr().String())
@@ -625,47 +568,116 @@ func TestWriteToTimeout(t *testing.T) {
t.Fatal(err)
}
- for i, tt := range writeToTimeoutTests {
- c2, err := ListenPacket(c1.LocalAddr().Network(), JoinHostPort(host, "0"))
- if err != nil {
- t.Fatal(err)
- }
- defer c2.Close()
+ timeouts := []time.Duration{
+ -5 * time.Second,
+ 10 * time.Millisecond,
+ }
- if err := c2.SetWriteDeadline(time.Now().Add(tt.timeout)); err != nil {
- t.Fatalf("#%d: %v", i, err)
- }
- for j, xerr := range tt.xerrs {
- for {
+ for _, timeout := range timeouts {
+ t.Run(fmt.Sprint(timeout), func(t *testing.T) {
+ c2, err := ListenPacket(c1.LocalAddr().Network(), JoinHostPort(host, "0"))
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer c2.Close()
+
+ if err := c2.SetWriteDeadline(time.Now().Add(timeout)); err != nil {
+ t.Fatalf("SetWriteDeadline: %v", err)
+ }
+ backoff := 1 * time.Millisecond
+ nDeadlineExceeded := 0
+ for j := 0; nDeadlineExceeded < 2; j++ {
n, err := c2.WriteTo([]byte("WRITETO TIMEOUT TEST"), c1.LocalAddr())
- if xerr != nil {
- if perr := parseWriteError(err); perr != nil {
- t.Errorf("#%d/%d: %v", i, j, perr)
- }
- if !isDeadlineExceeded(err) {
- t.Fatalf("#%d/%d: %v", i, j, err)
- }
+ t.Logf("#%d: WriteTo: %d, %v", j, n, err)
+ if err == nil && timeout >= 0 && nDeadlineExceeded == 0 {
+ // If the timeout is nonnegative, some number of WriteTo calls may
+ // succeed before the timeout takes effect.
+ t.Logf("WriteTo succeeded; sleeping %v", timeout/3)
+ time.Sleep(timeout / 3)
+ continue
}
- if err == nil {
- time.Sleep(tt.timeout / 3)
+ if isENOBUFS(err) {
+ t.Logf("WriteTo: %v", err)
+ // We're looking for a deadline exceeded error, but if the kernel's
+ // network buffers are saturated we may see ENOBUFS instead (see
+ // https://go.dev/issue/49930). Give it some time to unsaturate.
+ time.Sleep(backoff)
+ backoff *= 2
continue
}
+ if perr := parseWriteError(err); perr != nil {
+ t.Errorf("failed to parse error: %v", perr)
+ }
+ if !isDeadlineExceeded(err) {
+ t.Errorf("error is not 'deadline exceeded'")
+ }
if n != 0 {
- t.Fatalf("#%d/%d: wrote %d; want 0", i, j, n)
+ t.Errorf("unexpectedly wrote %d bytes", n)
}
- break
+ if !t.Failed() {
+ t.Logf("WriteTo timed out as expected")
+ }
+ nDeadlineExceeded++
}
- }
+ })
}
}
-func TestReadTimeoutFluctuation(t *testing.T) {
- t.Parallel()
+const (
+ // minDynamicTimeout is the minimum timeout to attempt for
+ // tests that automatically increase timeouts until success.
+ //
+ // Lower values may allow tests to succeed more quickly if the value is close
+ // to the true minimum, but may require more iterations (and waste more time
+ // and CPU power on failed attempts) if the timeout is too low.
+ minDynamicTimeout = 1 * time.Millisecond
+
+ // maxDynamicTimeout is the maximum timeout to attempt for
+ // tests that automatically increase timeouts until succeess.
+ //
+ // This should be a strict upper bound on the latency required to hit a
+ // timeout accurately, even on a slow or heavily-loaded machine. If a test
+ // would increase the timeout beyond this value, the test fails.
+ maxDynamicTimeout = 4 * time.Second
+)
- ln, err := newLocalListener("tcp")
- if err != nil {
- t.Fatal(err)
+// timeoutUpperBound returns the maximum time that we expect a timeout of
+// duration d to take to return the caller.
+func timeoutUpperBound(d time.Duration) time.Duration {
+ switch runtime.GOOS {
+ case "openbsd", "netbsd":
+ // NetBSD and OpenBSD seem to be unable to reliably hit deadlines even when
+ // the absolute durations are long.
+ // In https://build.golang.org/log/c34f8685d020b98377dd4988cd38f0c5bd72267e,
+ // we observed that an openbsd-amd64-68 builder took 4.090948779s for a
+ // 2.983020682s timeout (37.1% overhead).
+ // (See https://go.dev/issue/50189 for further detail.)
+ // Give them lots of slop to compensate.
+ return d * 3 / 2
+ }
+ // Other platforms seem to hit their deadlines more reliably,
+ // at least when they are long enough to cover scheduling jitter.
+ return d * 11 / 10
+}
+
+// nextTimeout returns the next timeout to try after an operation took the given
+// actual duration with a timeout shorter than that duration.
+func nextTimeout(actual time.Duration) (next time.Duration, ok bool) {
+ if actual >= maxDynamicTimeout {
+ return maxDynamicTimeout, false
+ }
+ // Since the previous attempt took actual, we can't expect to beat that
+ // duration by any significant margin. Try the next attempt with an arbitrary
+ // factor above that, so that our growth curve is at least exponential.
+ next = actual * 5 / 4
+ if next > maxDynamicTimeout {
+ return maxDynamicTimeout, true
}
+ return next, true
+}
+
+func TestReadTimeoutFluctuation(t *testing.T) {
+ ln := newLocalListener(t, "tcp")
defer ln.Close()
c, err := Dial(ln.Addr().Network(), ln.Addr().String())
@@ -674,31 +686,54 @@ func TestReadTimeoutFluctuation(t *testing.T) {
}
defer c.Close()
- max := time.NewTimer(time.Second)
- defer max.Stop()
- ch := make(chan error)
- go timeoutReceiver(c, 100*time.Millisecond, 50*time.Millisecond, 250*time.Millisecond, ch)
+ d := minDynamicTimeout
+ b := make([]byte, 256)
+ for {
+ t.Logf("SetReadDeadline(+%v)", d)
+ t0 := time.Now()
+ deadline := t0.Add(d)
+ if err = c.SetReadDeadline(deadline); err != nil {
+ t.Fatalf("SetReadDeadline(%v): %v", deadline, err)
+ }
+ var n int
+ n, err = c.Read(b)
+ t1 := time.Now()
- select {
- case <-max.C:
- t.Fatal("Read took over 1s; expected 0.1s")
- case err := <-ch:
+ if n != 0 || err == nil || !err.(Error).Timeout() {
+ t.Errorf("Read did not return (0, timeout): (%d, %v)", n, err)
+ }
if perr := parseReadError(err); perr != nil {
t.Error(perr)
}
if !isDeadlineExceeded(err) {
- t.Fatal(err)
+ t.Errorf("Read error is not DeadlineExceeded: %v", err)
}
+
+ actual := t1.Sub(t0)
+ if t1.Before(deadline) {
+ t.Errorf("Read took %s; expected at least %s", actual, d)
+ }
+ if t.Failed() {
+ return
+ }
+ if want := timeoutUpperBound(d); actual > want {
+ next, ok := nextTimeout(actual)
+ if !ok {
+ t.Fatalf("Read took %s; expected at most %v", actual, want)
+ }
+ // Maybe this machine is too slow to reliably schedule goroutines within
+ // the requested duration. Increase the timeout and try again.
+ t.Logf("Read took %s (expected %s); trying with longer timeout", actual, d)
+ d = next
+ continue
+ }
+
+ break
}
}
func TestReadFromTimeoutFluctuation(t *testing.T) {
- t.Parallel()
-
- c1, err := newLocalPacketListener("udp")
- if err != nil {
- t.Fatal(err)
- }
+ c1 := newLocalPacketListener(t, "udp")
defer c1.Close()
c2, err := Dial(c1.LocalAddr().Network(), c1.LocalAddr().String())
@@ -707,36 +742,59 @@ func TestReadFromTimeoutFluctuation(t *testing.T) {
}
defer c2.Close()
- max := time.NewTimer(time.Second)
- defer max.Stop()
- ch := make(chan error)
- go timeoutPacketReceiver(c2.(PacketConn), 100*time.Millisecond, 50*time.Millisecond, 250*time.Millisecond, ch)
+ d := minDynamicTimeout
+ b := make([]byte, 256)
+ for {
+ t.Logf("SetReadDeadline(+%v)", d)
+ t0 := time.Now()
+ deadline := t0.Add(d)
+ if err = c2.SetReadDeadline(deadline); err != nil {
+ t.Fatalf("SetReadDeadline(%v): %v", deadline, err)
+ }
+ var n int
+ n, _, err = c2.(PacketConn).ReadFrom(b)
+ t1 := time.Now()
- select {
- case <-max.C:
- t.Fatal("ReadFrom took over 1s; expected 0.1s")
- case err := <-ch:
+ if n != 0 || err == nil || !err.(Error).Timeout() {
+ t.Errorf("ReadFrom did not return (0, timeout): (%d, %v)", n, err)
+ }
if perr := parseReadError(err); perr != nil {
t.Error(perr)
}
if !isDeadlineExceeded(err) {
- t.Fatal(err)
+ t.Errorf("ReadFrom error is not DeadlineExceeded: %v", err)
+ }
+
+ actual := t1.Sub(t0)
+ if t1.Before(deadline) {
+ t.Errorf("ReadFrom took %s; expected at least %s", actual, d)
+ }
+ if t.Failed() {
+ return
+ }
+ if want := timeoutUpperBound(d); actual > want {
+ next, ok := nextTimeout(actual)
+ if !ok {
+ t.Fatalf("ReadFrom took %s; expected at most %s", actual, want)
+ }
+ // Maybe this machine is too slow to reliably schedule goroutines within
+ // the requested duration. Increase the timeout and try again.
+ t.Logf("ReadFrom took %s (expected %s); trying with longer timeout", actual, d)
+ d = next
+ continue
}
+
+ break
}
}
func TestWriteTimeoutFluctuation(t *testing.T) {
- t.Parallel()
-
switch runtime.GOOS {
case "plan9":
t.Skipf("not supported on %s", runtime.GOOS)
}
- ln, err := newLocalListener("tcp")
- if err != nil {
- t.Fatal(err)
- }
+ ln := newLocalListener(t, "tcp")
defer ln.Close()
c, err := Dial(ln.Addr().Network(), ln.Addr().String())
@@ -745,25 +803,67 @@ func TestWriteTimeoutFluctuation(t *testing.T) {
}
defer c.Close()
- d := time.Second
- if iOS() {
- d = 3 * time.Second // see golang.org/issue/10775
- }
- max := time.NewTimer(d)
- defer max.Stop()
- ch := make(chan error)
- go timeoutTransmitter(c, 100*time.Millisecond, 50*time.Millisecond, 250*time.Millisecond, ch)
+ d := minDynamicTimeout
+ for {
+ t.Logf("SetWriteDeadline(+%v)", d)
+ t0 := time.Now()
+ deadline := t0.Add(d)
+ if err = c.SetWriteDeadline(deadline); err != nil {
+ t.Fatalf("SetWriteDeadline(%v): %v", deadline, err)
+ }
+ var n int64
+ for {
+ var dn int
+ dn, err = c.Write([]byte("TIMEOUT TRANSMITTER"))
+ n += int64(dn)
+ if err != nil {
+ break
+ }
+ }
+ t1 := time.Now()
- select {
- case <-max.C:
- t.Fatalf("Write took over %v; expected 0.1s", d)
- case err := <-ch:
+ if err == nil || !err.(Error).Timeout() {
+ t.Fatalf("Write did not return (any, timeout): (%d, %v)", n, err)
+ }
if perr := parseWriteError(err); perr != nil {
t.Error(perr)
}
if !isDeadlineExceeded(err) {
- t.Fatal(err)
+ t.Errorf("Write error is not DeadlineExceeded: %v", err)
+ }
+
+ actual := t1.Sub(t0)
+ if t1.Before(deadline) {
+ t.Errorf("Write took %s; expected at least %s", actual, d)
}
+ if t.Failed() {
+ return
+ }
+ if want := timeoutUpperBound(d); actual > want {
+ if n > 0 {
+ // SetWriteDeadline specifies a time โ€œafter which I/O operations fail
+ // instead of blockingโ€. However, the kernel's send buffer is not yet
+ // full, we may be able to write some arbitrary (but finite) number of
+ // bytes to it without blocking.
+ t.Logf("Wrote %d bytes into send buffer; retrying until buffer is full", n)
+ if d <= maxDynamicTimeout/2 {
+ // We don't know how long the actual write loop would have taken if
+ // the buffer were full, so just guess and double the duration so that
+ // the next attempt can make twice as much progress toward filling it.
+ d *= 2
+ }
+ } else if next, ok := nextTimeout(actual); !ok {
+ t.Fatalf("Write took %s; expected at most %s", actual, want)
+ } else {
+ // Maybe this machine is too slow to reliably schedule goroutines within
+ // the requested duration. Increase the timeout and try again.
+ t.Logf("Write took %s (expected %s); trying with longer timeout", actual, d)
+ d = next
+ }
+ continue
+ }
+
+ break
}
}
@@ -819,10 +919,7 @@ func testVariousDeadlines(t *testing.T) {
c.Close()
}
}
- ls, err := newLocalServer("tcp")
- if err != nil {
- t.Fatal(err)
- }
+ ls := newLocalServer(t, "tcp")
defer ls.teardown()
if err := ls.buildup(handler); err != nil {
t.Fatal(err)
@@ -860,35 +957,23 @@ func testVariousDeadlines(t *testing.T) {
name := fmt.Sprintf("%v %d/%d", timeout, run, numRuns)
t.Log(name)
- tooSlow := time.NewTimer(5 * time.Second)
- defer tooSlow.Stop()
-
c, err := Dial(ls.Listener.Addr().Network(), ls.Listener.Addr().String())
if err != nil {
t.Fatal(err)
}
- ch := make(chan result, 1)
- go func() {
- t0 := time.Now()
- if err := c.SetDeadline(t0.Add(timeout)); err != nil {
- t.Error(err)
- }
- n, err := io.Copy(io.Discard, c)
- dt := time.Since(t0)
- c.Close()
- ch <- result{n, err, dt}
- }()
+ t0 := time.Now()
+ if err := c.SetDeadline(t0.Add(timeout)); err != nil {
+ t.Error(err)
+ }
+ n, err := io.Copy(io.Discard, c)
+ dt := time.Since(t0)
+ c.Close()
- select {
- case res := <-ch:
- if nerr, ok := res.err.(Error); ok && nerr.Timeout() {
- t.Logf("%v: good timeout after %v; %d bytes", name, res.d, res.n)
- } else {
- t.Fatalf("%v: Copy = %d, %v; want timeout", name, res.n, res.err)
- }
- case <-tooSlow.C:
- t.Fatalf("%v: client stuck in Dial+Copy", name)
+ if nerr, ok := err.(Error); ok && nerr.Timeout() {
+ t.Logf("%v: good timeout after %v; %d bytes", name, dt, n)
+ } else {
+ t.Fatalf("%v: Copy = %d, %v; want timeout", name, n, err)
}
}
}
@@ -954,10 +1039,7 @@ func TestReadWriteProlongedTimeout(t *testing.T) {
}()
wg.Wait()
}
- ls, err := newLocalServer("tcp")
- if err != nil {
- t.Fatal(err)
- }
+ ls := newLocalServer(t, "tcp")
defer ls.teardown()
if err := ls.buildup(handler); err != nil {
t.Fatal(err)
@@ -984,10 +1066,7 @@ func TestReadWriteDeadlineRace(t *testing.T) {
N = 50
}
- ln, err := newLocalListener("tcp")
- if err != nil {
- t.Fatal(err)
- }
+ ln := newLocalListener(t, "tcp")
defer ln.Close()
c, err := Dial(ln.Addr().Network(), ln.Addr().String())
@@ -1037,10 +1116,7 @@ func TestReadWriteDeadlineRace(t *testing.T) {
// Issue 35367.
func TestConcurrentSetDeadline(t *testing.T) {
- ln, err := newLocalListener("tcp")
- if err != nil {
- t.Fatal(err)
- }
+ ln := newLocalListener(t, "tcp")
defer ln.Close()
const goroutines = 8
@@ -1049,6 +1125,7 @@ func TestConcurrentSetDeadline(t *testing.T) {
var c [conns]Conn
for i := 0; i < conns; i++ {
+ var err error
c[i], err = Dial(ln.Addr().Network(), ln.Addr().String())
if err != nil {
t.Fatal(err)
diff --git a/libgo/go/net/udpsock.go b/libgo/go/net/udpsock.go
index 70f2ce2..6d29a39 100644
--- a/libgo/go/net/udpsock.go
+++ b/libgo/go/net/udpsock.go
@@ -7,6 +7,7 @@ package net
import (
"context"
"internal/itoa"
+ "net/netip"
"syscall"
)
@@ -26,6 +27,20 @@ type UDPAddr struct {
Zone string // IPv6 scoped addressing zone
}
+// AddrPort returns the UDPAddr a as a netip.AddrPort.
+//
+// If a.Port does not fit in a uint16, it's silently truncated.
+//
+// If a is nil, a zero value is returned.
+func (a *UDPAddr) AddrPort() netip.AddrPort {
+ if a == nil {
+ return netip.AddrPort{}
+ }
+ na, _ := netip.AddrFromSlice(a.IP)
+ na = na.WithZone(a.Zone)
+ return netip.AddrPortFrom(na, uint16(a.Port))
+}
+
// Network returns the address's network name, "udp".
func (a *UDPAddr) Network() string { return "udp" }
@@ -84,6 +99,24 @@ func ResolveUDPAddr(network, address string) (*UDPAddr, error) {
return addrs.forResolve(network, address).(*UDPAddr), nil
}
+// UDPAddrFromAddrPort returns addr as a UDPAddr. If addr.IsValid() is false,
+// then the returned UDPAddr will contain a nil IP field, indicating an
+// address family-agnostic unspecified address.
+func UDPAddrFromAddrPort(addr netip.AddrPort) *UDPAddr {
+ return &UDPAddr{
+ IP: addr.Addr().AsSlice(),
+ Zone: addr.Addr().Zone(),
+ Port: int(addr.Port()),
+ }
+}
+
+// An addrPortUDPAddr is a netip.AddrPort-based UDP address that satisfies the Addr interface.
+type addrPortUDPAddr struct {
+ netip.AddrPort
+}
+
+func (addrPortUDPAddr) Network() string { return "udp" }
+
// UDPConn is the implementation of the Conn and PacketConn interfaces
// for UDP network connections.
type UDPConn struct {
@@ -130,6 +163,18 @@ func (c *UDPConn) ReadFrom(b []byte) (int, Addr, error) {
return n, addr, err
}
+// ReadFromUDPAddrPort acts like ReadFrom but returns a netip.AddrPort.
+func (c *UDPConn) ReadFromUDPAddrPort(b []byte) (n int, addr netip.AddrPort, err error) {
+ if !c.ok() {
+ return 0, netip.AddrPort{}, syscall.EINVAL
+ }
+ n, addr, err = c.readFromAddrPort(b)
+ if err != nil {
+ err = &OpError{Op: "read", Net: c.fd.net, Source: c.fd.laddr, Addr: c.fd.raddr, Err: err}
+ }
+ return n, addr, err
+}
+
// ReadMsgUDP reads a message from c, copying the payload into b and
// the associated out-of-band data into oob. It returns the number of
// bytes copied into b, the number of bytes copied into oob, the flags
@@ -138,8 +183,18 @@ func (c *UDPConn) ReadFrom(b []byte) (int, Addr, error) {
// The packages golang.org/x/net/ipv4 and golang.org/x/net/ipv6 can be
// used to manipulate IP-level socket options in oob.
func (c *UDPConn) ReadMsgUDP(b, oob []byte) (n, oobn, flags int, addr *UDPAddr, err error) {
+ var ap netip.AddrPort
+ n, oobn, flags, ap, err = c.ReadMsgUDPAddrPort(b, oob)
+ if ap.IsValid() {
+ addr = UDPAddrFromAddrPort(ap)
+ }
+ return
+}
+
+// ReadMsgUDPAddrPort is like ReadMsgUDP but returns an netip.AddrPort instead of a UDPAddr.
+func (c *UDPConn) ReadMsgUDPAddrPort(b, oob []byte) (n, oobn, flags int, addr netip.AddrPort, err error) {
if !c.ok() {
- return 0, 0, 0, nil, syscall.EINVAL
+ return 0, 0, 0, netip.AddrPort{}, syscall.EINVAL
}
n, oobn, flags, addr, err = c.readMsg(b, oob)
if err != nil {
@@ -160,6 +215,18 @@ func (c *UDPConn) WriteToUDP(b []byte, addr *UDPAddr) (int, error) {
return n, err
}
+// WriteToUDPAddrPort acts like WriteTo but takes a netip.AddrPort.
+func (c *UDPConn) WriteToUDPAddrPort(b []byte, addr netip.AddrPort) (int, error) {
+ if !c.ok() {
+ return 0, syscall.EINVAL
+ }
+ n, err := c.writeToAddrPort(b, addr)
+ if err != nil {
+ err = &OpError{Op: "write", Net: c.fd.net, Source: c.fd.laddr, Addr: addrPortUDPAddr{addr}, Err: err}
+ }
+ return n, err
+}
+
// WriteTo implements the PacketConn WriteTo method.
func (c *UDPConn) WriteTo(b []byte, addr Addr) (int, error) {
if !c.ok() {
@@ -195,6 +262,18 @@ func (c *UDPConn) WriteMsgUDP(b, oob []byte, addr *UDPAddr) (n, oobn int, err er
return
}
+// WriteMsgUDPAddrPort is like WriteMsgUDP but takes a netip.AddrPort instead of a UDPAddr.
+func (c *UDPConn) WriteMsgUDPAddrPort(b, oob []byte, addr netip.AddrPort) (n, oobn int, err error) {
+ if !c.ok() {
+ return 0, 0, syscall.EINVAL
+ }
+ n, oobn, err = c.writeMsgAddrPort(b, oob, addr)
+ if err != nil {
+ err = &OpError{Op: "write", Net: c.fd.net, Source: c.fd.laddr, Addr: addrPortUDPAddr{addr}, Err: err}
+ }
+ return
+}
+
func newUDPConn(fd *netFD) *UDPConn { return &UDPConn{conn{fd}} }
// DialUDP acts like Dial for UDP networks.
diff --git a/libgo/go/net/udpsock_plan9.go b/libgo/go/net/udpsock_plan9.go
index 1df293d..732a3b0 100644
--- a/libgo/go/net/udpsock_plan9.go
+++ b/libgo/go/net/udpsock_plan9.go
@@ -7,6 +7,7 @@ package net
import (
"context"
"errors"
+ "net/netip"
"os"
"syscall"
)
@@ -28,8 +29,27 @@ func (c *UDPConn) readFrom(b []byte, addr *UDPAddr) (int, *UDPAddr, error) {
return n, addr, nil
}
-func (c *UDPConn) readMsg(b, oob []byte) (n, oobn, flags int, addr *UDPAddr, err error) {
- return 0, 0, 0, nil, syscall.EPLAN9
+func (c *UDPConn) readFromAddrPort(b []byte) (int, netip.AddrPort, error) {
+ // TODO: optimize. The equivalent code on posix is alloc-free.
+ buf := make([]byte, udpHeaderSize+len(b))
+ m, err := c.fd.Read(buf)
+ if err != nil {
+ return 0, netip.AddrPort{}, err
+ }
+ if m < udpHeaderSize {
+ return 0, netip.AddrPort{}, errors.New("short read reading UDP header")
+ }
+ buf = buf[:m]
+
+ h, buf := unmarshalUDPHeader(buf)
+ n := copy(b, buf)
+ ip, _ := netip.AddrFromSlice(h.raddr)
+ addr := netip.AddrPortFrom(ip, h.rport)
+ return n, addr, nil
+}
+
+func (c *UDPConn) readMsg(b, oob []byte) (n, oobn, flags int, addr netip.AddrPort, err error) {
+ return 0, 0, 0, netip.AddrPort{}, syscall.EPLAN9
}
func (c *UDPConn) writeTo(b []byte, addr *UDPAddr) (int, error) {
@@ -52,10 +72,18 @@ func (c *UDPConn) writeTo(b []byte, addr *UDPAddr) (int, error) {
return len(b), nil
}
+func (c *UDPConn) writeToAddrPort(b []byte, addr netip.AddrPort) (int, error) {
+ return c.writeTo(b, UDPAddrFromAddrPort(addr)) // TODO: optimize instead of allocating
+}
+
func (c *UDPConn) writeMsg(b, oob []byte, addr *UDPAddr) (n, oobn int, err error) {
return 0, 0, syscall.EPLAN9
}
+func (c *UDPConn) writeMsgAddrPort(b, oob []byte, addr netip.AddrPort) (n, oobn int, err error) {
+ return 0, 0, syscall.EPLAN9
+}
+
func (sd *sysDialer) dialUDP(ctx context.Context, laddr, raddr *UDPAddr) (*UDPConn, error) {
fd, err := dialPlan9(ctx, sd.network, laddr, raddr)
if err != nil {
diff --git a/libgo/go/net/udpsock_posix.go b/libgo/go/net/udpsock_posix.go
index a4c6da2..a435658 100644
--- a/libgo/go/net/udpsock_posix.go
+++ b/libgo/go/net/udpsock_posix.go
@@ -3,12 +3,12 @@
// license that can be found in the LICENSE file.
//go:build aix || darwin || dragonfly || freebsd || hurd || (js && wasm) || linux || netbsd || openbsd || solaris || windows
-// +build aix darwin dragonfly freebsd hurd js,wasm linux netbsd openbsd solaris windows
package net
import (
"context"
+ "net/netip"
"syscall"
)
@@ -44,27 +44,68 @@ func (a *UDPAddr) toLocal(net string) sockaddr {
}
func (c *UDPConn) readFrom(b []byte, addr *UDPAddr) (int, *UDPAddr, error) {
- n, sa, err := c.fd.readFrom(b)
- switch sa := sa.(type) {
- case *syscall.SockaddrInet4:
- *addr = UDPAddr{IP: sa.Addr[0:], Port: sa.Port}
- case *syscall.SockaddrInet6:
- *addr = UDPAddr{IP: sa.Addr[0:], Port: sa.Port, Zone: zoneCache.name(int(sa.ZoneId))}
- default:
+ var n int
+ var err error
+ switch c.fd.family {
+ case syscall.AF_INET:
+ var from syscall.SockaddrInet4
+ n, err = c.fd.readFromInet4(b, &from)
+ if err == nil {
+ ip := from.Addr // copy from.Addr; ip escapes, so this line allocates 4 bytes
+ *addr = UDPAddr{IP: ip[:], Port: from.Port}
+ }
+ case syscall.AF_INET6:
+ var from syscall.SockaddrInet6
+ n, err = c.fd.readFromInet6(b, &from)
+ if err == nil {
+ ip := from.Addr // copy from.Addr; ip escapes, so this line allocates 16 bytes
+ *addr = UDPAddr{IP: ip[:], Port: from.Port, Zone: zoneCache.name(int(from.ZoneId))}
+ }
+ }
+ if err != nil {
// No sockaddr, so don't return UDPAddr.
addr = nil
}
return n, addr, err
}
-func (c *UDPConn) readMsg(b, oob []byte) (n, oobn, flags int, addr *UDPAddr, err error) {
- var sa syscall.Sockaddr
- n, oobn, flags, sa, err = c.fd.readMsg(b, oob, 0)
- switch sa := sa.(type) {
- case *syscall.SockaddrInet4:
- addr = &UDPAddr{IP: sa.Addr[0:], Port: sa.Port}
- case *syscall.SockaddrInet6:
- addr = &UDPAddr{IP: sa.Addr[0:], Port: sa.Port, Zone: zoneCache.name(int(sa.ZoneId))}
+func (c *UDPConn) readFromAddrPort(b []byte) (n int, addr netip.AddrPort, err error) {
+ var ip netip.Addr
+ var port int
+ switch c.fd.family {
+ case syscall.AF_INET:
+ var from syscall.SockaddrInet4
+ n, err = c.fd.readFromInet4(b, &from)
+ if err == nil {
+ ip = netip.AddrFrom4(from.Addr)
+ port = from.Port
+ }
+ case syscall.AF_INET6:
+ var from syscall.SockaddrInet6
+ n, err = c.fd.readFromInet6(b, &from)
+ if err == nil {
+ ip = netip.AddrFrom16(from.Addr).WithZone(zoneCache.name(int(from.ZoneId)))
+ port = from.Port
+ }
+ }
+ if err == nil {
+ addr = netip.AddrPortFrom(ip, uint16(port))
+ }
+ return n, addr, err
+}
+
+func (c *UDPConn) readMsg(b, oob []byte) (n, oobn, flags int, addr netip.AddrPort, err error) {
+ switch c.fd.family {
+ case syscall.AF_INET:
+ var sa syscall.SockaddrInet4
+ n, oobn, flags, err = c.fd.readMsgInet4(b, oob, 0, &sa)
+ ip := netip.AddrFrom4(sa.Addr)
+ addr = netip.AddrPortFrom(ip, uint16(sa.Port))
+ case syscall.AF_INET6:
+ var sa syscall.SockaddrInet6
+ n, oobn, flags, err = c.fd.readMsgInet6(b, oob, 0, &sa)
+ ip := netip.AddrFrom16(sa.Addr).WithZone(zoneCache.name(int(sa.ZoneId)))
+ addr = netip.AddrPortFrom(ip, uint16(sa.Port))
}
return
}
@@ -76,11 +117,49 @@ func (c *UDPConn) writeTo(b []byte, addr *UDPAddr) (int, error) {
if addr == nil {
return 0, errMissingAddress
}
- sa, err := addr.sockaddr(c.fd.family)
- if err != nil {
- return 0, err
+
+ switch c.fd.family {
+ case syscall.AF_INET:
+ sa, err := ipToSockaddrInet4(addr.IP, addr.Port)
+ if err != nil {
+ return 0, err
+ }
+ return c.fd.writeToInet4(b, &sa)
+ case syscall.AF_INET6:
+ sa, err := ipToSockaddrInet6(addr.IP, addr.Port, addr.Zone)
+ if err != nil {
+ return 0, err
+ }
+ return c.fd.writeToInet6(b, &sa)
+ default:
+ return 0, &AddrError{Err: "invalid address family", Addr: addr.IP.String()}
+ }
+}
+
+func (c *UDPConn) writeToAddrPort(b []byte, addr netip.AddrPort) (int, error) {
+ if c.fd.isConnected {
+ return 0, ErrWriteToConnected
+ }
+ if !addr.IsValid() {
+ return 0, errMissingAddress
+ }
+
+ switch c.fd.family {
+ case syscall.AF_INET:
+ sa, err := addrPortToSockaddrInet4(addr)
+ if err != nil {
+ return 0, err
+ }
+ return c.fd.writeToInet4(b, &sa)
+ case syscall.AF_INET6:
+ sa, err := addrPortToSockaddrInet6(addr)
+ if err != nil {
+ return 0, err
+ }
+ return c.fd.writeToInet6(b, &sa)
+ default:
+ return 0, &AddrError{Err: "invalid address family", Addr: addr.Addr().String()}
}
- return c.fd.writeTo(b, sa)
}
func (c *UDPConn) writeMsg(b, oob []byte, addr *UDPAddr) (n, oobn int, err error) {
@@ -97,6 +176,32 @@ func (c *UDPConn) writeMsg(b, oob []byte, addr *UDPAddr) (n, oobn int, err error
return c.fd.writeMsg(b, oob, sa)
}
+func (c *UDPConn) writeMsgAddrPort(b, oob []byte, addr netip.AddrPort) (n, oobn int, err error) {
+ if c.fd.isConnected && addr.IsValid() {
+ return 0, 0, ErrWriteToConnected
+ }
+ if !c.fd.isConnected && !addr.IsValid() {
+ return 0, 0, errMissingAddress
+ }
+
+ switch c.fd.family {
+ case syscall.AF_INET:
+ sa, err := addrPortToSockaddrInet4(addr)
+ if err != nil {
+ return 0, 0, err
+ }
+ return c.fd.writeMsgInet4(b, oob, &sa)
+ case syscall.AF_INET6:
+ sa, err := addrPortToSockaddrInet6(addr)
+ if err != nil {
+ return 0, 0, err
+ }
+ return c.fd.writeMsgInet6(b, oob, &sa)
+ default:
+ return 0, 0, &AddrError{Err: "invalid address family", Addr: addr.Addr().String()}
+ }
+}
+
func (sd *sysDialer) dialUDP(ctx context.Context, laddr, raddr *UDPAddr) (*UDPConn, error) {
fd, err := internetSocket(ctx, sd.network, laddr, raddr, syscall.SOCK_DGRAM, 0, "dial", sd.Dialer.Control)
if err != nil {
diff --git a/libgo/go/net/udpsock_test.go b/libgo/go/net/udpsock_test.go
index 0e8c351..21f5af5 100644
--- a/libgo/go/net/udpsock_test.go
+++ b/libgo/go/net/udpsock_test.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build !js
-// +build !js
package net
@@ -286,10 +285,7 @@ func TestIPv6LinkLocalUnicastUDP(t *testing.T) {
t.Log(err)
continue
}
- ls, err := (&packetListener{PacketConn: c1}).newLocalServer()
- if err != nil {
- t.Fatal(err)
- }
+ ls := (&packetListener{PacketConn: c1}).newLocalServer()
defer ls.teardown()
ch := make(chan error, 1)
handler := func(ls *localPacketServer, c PacketConn) { packetTransponder(c, ch) }
@@ -334,10 +330,7 @@ func TestUDPZeroBytePayload(t *testing.T) {
testenv.SkipFlaky(t, 29225)
}
- c, err := newLocalPacketListener("udp")
- if err != nil {
- t.Fatal(err)
- }
+ c := newLocalPacketListener(t, "udp")
defer c.Close()
for _, genericRead := range []bool{false, true} {
@@ -370,10 +363,7 @@ func TestUDPZeroByteBuffer(t *testing.T) {
t.Skipf("not supported on %s", runtime.GOOS)
}
- c, err := newLocalPacketListener("udp")
- if err != nil {
- t.Fatal(err)
- }
+ c := newLocalPacketListener(t, "udp")
defer c.Close()
b := []byte("UDP ZERO BYTE BUFFER TEST")
@@ -407,10 +397,7 @@ func TestUDPReadSizeError(t *testing.T) {
t.Skipf("not supported on %s", runtime.GOOS)
}
- c1, err := newLocalPacketListener("udp")
- if err != nil {
- t.Fatal(err)
- }
+ c1 := newLocalPacketListener(t, "udp")
defer c1.Close()
c2, err := Dial("udp", c1.LocalAddr().String())
@@ -475,11 +462,100 @@ func TestUDPReadTimeout(t *testing.T) {
}
}
+func TestAllocs(t *testing.T) {
+ switch runtime.GOOS {
+ case "plan9":
+ // Plan9 wasn't optimized.
+ t.Skipf("skipping on %v", runtime.GOOS)
+ }
+ builder := os.Getenv("GO_BUILDER_NAME")
+ switch builder {
+ case "linux-amd64-noopt":
+ // Optimizations are required to remove the allocs.
+ t.Skipf("skipping on %v", builder)
+ }
+ conn, err := ListenUDP("udp4", &UDPAddr{IP: IPv4(127, 0, 0, 1)})
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer conn.Close()
+ addr := conn.LocalAddr()
+ addrPort := addr.(*UDPAddr).AddrPort()
+ buf := make([]byte, 8)
+
+ allocs := testing.AllocsPerRun(1000, func() {
+ _, _, err := conn.WriteMsgUDPAddrPort(buf, nil, addrPort)
+ if err != nil {
+ t.Fatal(err)
+ }
+ _, _, _, _, err = conn.ReadMsgUDPAddrPort(buf, nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ })
+ if got := int(allocs); got != 0 {
+ t.Errorf("WriteMsgUDPAddrPort/ReadMsgUDPAddrPort allocated %d objects", got)
+ }
+
+ allocs = testing.AllocsPerRun(1000, func() {
+ _, err := conn.WriteToUDPAddrPort(buf, addrPort)
+ if err != nil {
+ t.Fatal(err)
+ }
+ _, _, err = conn.ReadFromUDPAddrPort(buf)
+ if err != nil {
+ t.Fatal(err)
+ }
+ })
+ if got := int(allocs); got != 0 {
+ t.Errorf("WriteToUDPAddrPort/ReadFromUDPAddrPort allocated %d objects", got)
+ }
+
+ allocs = testing.AllocsPerRun(1000, func() {
+ _, err := conn.WriteTo(buf, addr)
+ if err != nil {
+ t.Fatal(err)
+ }
+ _, _, err = conn.ReadFromUDP(buf)
+ if err != nil {
+ t.Fatal(err)
+ }
+ })
+ if got := int(allocs); got != 1 {
+ if runtime.Compiler != "gccgo" {
+ t.Errorf("WriteTo/ReadFromUDP allocated %d objects", got)
+ }
+ }
+}
+
+func BenchmarkReadWriteMsgUDPAddrPort(b *testing.B) {
+ conn, err := ListenUDP("udp4", &UDPAddr{IP: IPv4(127, 0, 0, 1)})
+ if err != nil {
+ b.Fatal(err)
+ }
+ defer conn.Close()
+ addr := conn.LocalAddr().(*UDPAddr).AddrPort()
+ buf := make([]byte, 8)
+ b.ResetTimer()
+ b.ReportAllocs()
+ for i := 0; i < b.N; i++ {
+ _, _, err := conn.WriteMsgUDPAddrPort(buf, nil, addr)
+ if err != nil {
+ b.Fatal(err)
+ }
+ _, _, _, _, err = conn.ReadMsgUDPAddrPort(buf, nil)
+ if err != nil {
+ b.Fatal(err)
+ }
+ }
+}
+
func BenchmarkWriteToReadFromUDP(b *testing.B) {
conn, err := ListenUDP("udp4", &UDPAddr{IP: IPv4(127, 0, 0, 1)})
if err != nil {
b.Fatal(err)
}
+ defer conn.Close()
addr := conn.LocalAddr()
buf := make([]byte, 8)
b.ResetTimer()
@@ -495,3 +571,61 @@ func BenchmarkWriteToReadFromUDP(b *testing.B) {
}
}
}
+
+func BenchmarkWriteToReadFromUDPAddrPort(b *testing.B) {
+ conn, err := ListenUDP("udp4", &UDPAddr{IP: IPv4(127, 0, 0, 1)})
+ if err != nil {
+ b.Fatal(err)
+ }
+ defer conn.Close()
+ addr := conn.LocalAddr().(*UDPAddr).AddrPort()
+ buf := make([]byte, 8)
+ b.ResetTimer()
+ b.ReportAllocs()
+ for i := 0; i < b.N; i++ {
+ _, err := conn.WriteToUDPAddrPort(buf, addr)
+ if err != nil {
+ b.Fatal(err)
+ }
+ _, _, err = conn.ReadFromUDPAddrPort(buf)
+ if err != nil {
+ b.Fatal(err)
+ }
+ }
+}
+
+func TestUDPIPVersionReadMsg(t *testing.T) {
+ switch runtime.GOOS {
+ case "plan9":
+ t.Skipf("skipping on %v", runtime.GOOS)
+ }
+ conn, err := ListenUDP("udp4", &UDPAddr{IP: IPv4(127, 0, 0, 1)})
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer conn.Close()
+ daddr := conn.LocalAddr().(*UDPAddr).AddrPort()
+ buf := make([]byte, 8)
+ _, err = conn.WriteToUDPAddrPort(buf, daddr)
+ if err != nil {
+ t.Fatal(err)
+ }
+ _, _, _, saddr, err := conn.ReadMsgUDPAddrPort(buf, nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if !saddr.Addr().Is4() {
+ t.Error("returned AddrPort is not IPv4")
+ }
+ _, err = conn.WriteToUDPAddrPort(buf, daddr)
+ if err != nil {
+ t.Fatal(err)
+ }
+ _, _, _, soldaddr, err := conn.ReadMsgUDP(buf, nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if len(soldaddr.IP) != 4 {
+ t.Error("returned UDPAddr is not IPv4")
+ }
+}
diff --git a/libgo/go/net/unixsock_posix.go b/libgo/go/net/unixsock_posix.go
index af075af..927b533 100644
--- a/libgo/go/net/unixsock_posix.go
+++ b/libgo/go/net/unixsock_posix.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build aix || darwin || dragonfly || freebsd || hurd || (js && wasm) || linux || netbsd || openbsd || solaris || windows
-// +build aix darwin dragonfly freebsd hurd js,wasm linux netbsd openbsd solaris windows
package net
diff --git a/libgo/go/net/unixsock_readmsg_cloexec.go b/libgo/go/net/unixsock_readmsg_cloexec.go
index 716484c..fa4fd7d 100644
--- a/libgo/go/net/unixsock_readmsg_cloexec.go
+++ b/libgo/go/net/unixsock_readmsg_cloexec.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build aix || darwin || freebsd || solaris
-// +build aix darwin freebsd solaris
package net
diff --git a/libgo/go/net/unixsock_readmsg_cmsg_cloexec.go b/libgo/go/net/unixsock_readmsg_cmsg_cloexec.go
index bb851b8..6b0de87 100644
--- a/libgo/go/net/unixsock_readmsg_cmsg_cloexec.go
+++ b/libgo/go/net/unixsock_readmsg_cmsg_cloexec.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build dragonfly || linux || netbsd || openbsd
-// +build dragonfly linux netbsd openbsd
package net
diff --git a/libgo/go/net/unixsock_readmsg_other.go b/libgo/go/net/unixsock_readmsg_other.go
index 3290761..b3d19fe 100644
--- a/libgo/go/net/unixsock_readmsg_other.go
+++ b/libgo/go/net/unixsock_readmsg_other.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build (js && wasm) || windows
-// +build js,wasm windows
package net
diff --git a/libgo/go/net/unixsock_readmsg_test.go b/libgo/go/net/unixsock_readmsg_test.go
index a4d2fca..c3bfbf9 100644
--- a/libgo/go/net/unixsock_readmsg_test.go
+++ b/libgo/go/net/unixsock_readmsg_test.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris
-// +build aix darwin dragonfly freebsd linux netbsd openbsd solaris
package net
diff --git a/libgo/go/net/unixsock_test.go b/libgo/go/net/unixsock_test.go
index 71092e8..2fc9580 100644
--- a/libgo/go/net/unixsock_test.go
+++ b/libgo/go/net/unixsock_test.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build !js && !plan9 && !windows
-// +build !js,!plan9,!windows
package net
@@ -26,7 +25,7 @@ func TestReadUnixgramWithUnnamedSocket(t *testing.T) {
testenv.SkipFlaky(t, 15157)
}
- addr := testUnixAddr()
+ addr := testUnixAddr(t)
la, err := ResolveUnixAddr("unixgram", addr)
if err != nil {
t.Fatal(err)
@@ -77,10 +76,7 @@ func TestUnixgramZeroBytePayload(t *testing.T) {
t.Skip("unixgram test")
}
- c1, err := newLocalPacketListener("unixgram")
- if err != nil {
- t.Fatal(err)
- }
+ c1 := newLocalPacketListener(t, "unixgram")
defer os.Remove(c1.LocalAddr().String())
defer c1.Close()
@@ -127,10 +123,7 @@ func TestUnixgramZeroByteBuffer(t *testing.T) {
// issue 4352: Recvfrom failed with "address family not
// supported by protocol family" if zero-length buffer provided
- c1, err := newLocalPacketListener("unixgram")
- if err != nil {
- t.Fatal(err)
- }
+ c1 := newLocalPacketListener(t, "unixgram")
defer os.Remove(c1.LocalAddr().String())
defer c1.Close()
@@ -175,7 +168,7 @@ func TestUnixgramWrite(t *testing.T) {
t.Skip("unixgram test")
}
- addr := testUnixAddr()
+ addr := testUnixAddr(t)
laddr, err := ResolveUnixAddr("unixgram", addr)
if err != nil {
t.Fatal(err)
@@ -220,7 +213,7 @@ func testUnixgramWriteConn(t *testing.T, raddr *UnixAddr) {
}
func testUnixgramWritePacketConn(t *testing.T, raddr *UnixAddr) {
- addr := testUnixAddr()
+ addr := testUnixAddr(t)
c, err := ListenPacket("unixgram", addr)
if err != nil {
t.Fatal(err)
@@ -249,9 +242,9 @@ func TestUnixConnLocalAndRemoteNames(t *testing.T) {
}
handler := func(ls *localServer, ln Listener) {}
- for _, laddr := range []string{"", testUnixAddr()} {
+ for _, laddr := range []string{"", testUnixAddr(t)} {
laddr := laddr
- taddr := testUnixAddr()
+ taddr := testUnixAddr(t)
ta, err := ResolveUnixAddr("unix", taddr)
if err != nil {
t.Fatal(err)
@@ -260,10 +253,7 @@ func TestUnixConnLocalAndRemoteNames(t *testing.T) {
if err != nil {
t.Fatal(err)
}
- ls, err := (&streamListener{Listener: ln}).newLocalServer()
- if err != nil {
- t.Fatal(err)
- }
+ ls := (&streamListener{Listener: ln}).newLocalServer()
defer ls.teardown()
if err := ls.buildup(handler); err != nil {
t.Fatal(err)
@@ -311,9 +301,9 @@ func TestUnixgramConnLocalAndRemoteNames(t *testing.T) {
t.Skip("unixgram test")
}
- for _, laddr := range []string{"", testUnixAddr()} {
+ for _, laddr := range []string{"", testUnixAddr(t)} {
laddr := laddr
- taddr := testUnixAddr()
+ taddr := testUnixAddr(t)
ta, err := ResolveUnixAddr("unixgram", taddr)
if err != nil {
t.Fatal(err)
@@ -369,7 +359,7 @@ func TestUnixUnlink(t *testing.T) {
if !testableNetwork("unix") {
t.Skip("unix test")
}
- name := testUnixAddr()
+ name := testUnixAddr(t)
listen := func(t *testing.T) *UnixListener {
l, err := Listen("unix", name)
diff --git a/libgo/go/net/unixsock_windows_test.go b/libgo/go/net/unixsock_windows_test.go
index 29244f6..d541d89 100644
--- a/libgo/go/net/unixsock_windows_test.go
+++ b/libgo/go/net/unixsock_windows_test.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build windows
-// +build windows
package net
@@ -46,9 +45,9 @@ func TestUnixConnLocalWindows(t *testing.T) {
}
handler := func(ls *localServer, ln Listener) {}
- for _, laddr := range []string{"", testUnixAddr()} {
+ for _, laddr := range []string{"", testUnixAddr(t)} {
laddr := laddr
- taddr := testUnixAddr()
+ taddr := testUnixAddr(t)
ta, err := ResolveUnixAddr("unix", taddr)
if err != nil {
t.Fatal(err)
@@ -57,10 +56,7 @@ func TestUnixConnLocalWindows(t *testing.T) {
if err != nil {
t.Fatal(err)
}
- ls, err := (&streamListener{Listener: ln}).newLocalServer()
- if err != nil {
- t.Fatal(err)
- }
+ ls := (&streamListener{Listener: ln}).newLocalServer()
defer ls.teardown()
if err := ls.buildup(handler); err != nil {
t.Fatal(err)
diff --git a/libgo/go/net/url/url.go b/libgo/go/net/url/url.go
index 20de0f6..f31aa08 100644
--- a/libgo/go/net/url/url.go
+++ b/libgo/go/net/url/url.go
@@ -452,20 +452,6 @@ func getScheme(rawURL string) (scheme, path string, err error) {
return "", rawURL, nil
}
-// split slices s into two substrings separated by the first occurrence of
-// sep. If cutc is true then sep is excluded from the second substring.
-// If sep does not occur in s then s and the empty string is returned.
-func split(s string, sep byte, cutc bool) (string, string) {
- i := strings.IndexByte(s, sep)
- if i < 0 {
- return s, ""
- }
- if cutc {
- return s[:i], s[i+1:]
- }
- return s[:i], s[i:]
-}
-
// Parse parses a raw url into a URL structure.
//
// The url may be relative (a path, without a host) or absolute
@@ -474,7 +460,7 @@ func split(s string, sep byte, cutc bool) (string, string) {
// error, due to parsing ambiguities.
func Parse(rawURL string) (*URL, error) {
// Cut off #frag
- u, frag := split(rawURL, '#', true)
+ u, frag, _ := strings.Cut(rawURL, "#")
url, err := parse(u, false)
if err != nil {
return nil, &Error{"parse", u, err}
@@ -534,7 +520,7 @@ func parse(rawURL string, viaRequest bool) (*URL, error) {
url.ForceQuery = true
rest = rest[:len(rest)-1]
} else {
- rest, url.RawQuery = split(rest, '?', true)
+ rest, url.RawQuery, _ = strings.Cut(rest, "?")
}
if !strings.HasPrefix(rest, "/") {
@@ -553,9 +539,7 @@ func parse(rawURL string, viaRequest bool) (*URL, error) {
// RFC 3986, ยง3.3:
// In addition, a URI reference (Section 4.1) may be a relative-path reference,
// in which case the first path segment cannot contain a colon (":") character.
- colon := strings.Index(rest, ":")
- slash := strings.Index(rest, "/")
- if colon >= 0 && (slash < 0 || colon < slash) {
+ if segment, _, _ := strings.Cut(rest, "/"); strings.Contains(segment, ":") {
// First path segment has colon. Not allowed in relative URL.
return nil, errors.New("first path segment in URL cannot contain colon")
}
@@ -563,7 +547,10 @@ func parse(rawURL string, viaRequest bool) (*URL, error) {
if (url.Scheme != "" || !viaRequest && !strings.HasPrefix(rest, "///")) && strings.HasPrefix(rest, "//") {
var authority string
- authority, rest = split(rest[2:], '/', false)
+ authority, rest = rest[2:], ""
+ if i := strings.Index(authority, "/"); i >= 0 {
+ authority, rest = authority[:i], authority[i:]
+ }
url.User, url.Host, err = parseAuthority(authority)
if err != nil {
return nil, err
@@ -602,7 +589,7 @@ func parseAuthority(authority string) (user *Userinfo, host string, err error) {
}
user = User(userinfo)
} else {
- username, password := split(userinfo, ':', true)
+ username, password, _ := strings.Cut(userinfo, ":")
if username, err = unescape(username, encodeUserPassword); err != nil {
return nil, "", err
}
@@ -840,7 +827,7 @@ func (u *URL) String() string {
// it would be mistaken for a scheme name. Such a segment must be
// preceded by a dot-segment (e.g., "./this:that") to make a relative-
// path reference.
- if i := strings.IndexByte(path, ':'); i > -1 && strings.IndexByte(path[:i], '/') == -1 {
+ if segment, _, _ := strings.Cut(path, "/"); strings.Contains(segment, ":") {
buf.WriteString("./")
}
}
@@ -933,12 +920,8 @@ func ParseQuery(query string) (Values, error) {
func parseQuery(m Values, query string) (err error) {
for query != "" {
- key := query
- if i := strings.IndexAny(key, "&"); i >= 0 {
- key, query = key[:i], key[i+1:]
- } else {
- query = ""
- }
+ var key string
+ key, query, _ = strings.Cut(query, "&")
if strings.Contains(key, ";") {
err = fmt.Errorf("invalid semicolon separator in query")
continue
@@ -946,10 +929,7 @@ func parseQuery(m Values, query string) (err error) {
if key == "" {
continue
}
- value := ""
- if i := strings.Index(key, "="); i >= 0 {
- key, value = key[:i], key[i+1:]
- }
+ key, value, _ := strings.Cut(key, "=")
key, err1 := QueryUnescape(key)
if err1 != nil {
if err == nil {
@@ -1013,22 +993,16 @@ func resolvePath(base, ref string) string {
}
var (
- last string
elem string
- i int
dst strings.Builder
)
first := true
remaining := full
// We want to return a leading '/', so write it now.
dst.WriteByte('/')
- for i >= 0 {
- i = strings.IndexByte(remaining, '/')
- if i < 0 {
- last, elem, remaining = remaining, remaining, ""
- } else {
- elem, remaining = remaining[:i], remaining[i+1:]
- }
+ found := true
+ for found {
+ elem, remaining, found = strings.Cut(remaining, "/")
if elem == "." {
first = false
// drop
@@ -1056,7 +1030,7 @@ func resolvePath(base, ref string) string {
}
}
- if last == "." || last == ".." {
+ if elem == "." || elem == ".." {
dst.WriteByte('/')
}
@@ -1109,7 +1083,7 @@ func (u *URL) ResolveReference(ref *URL) *URL {
url.Path = ""
return &url
}
- if ref.Path == "" && ref.RawQuery == "" {
+ if ref.Path == "" && !ref.ForceQuery && ref.RawQuery == "" {
url.RawQuery = u.RawQuery
if ref.Fragment == "" {
url.Fragment = u.Fragment
diff --git a/libgo/go/net/url/url_test.go b/libgo/go/net/url/url_test.go
index 63c8e69..664757b 100644
--- a/libgo/go/net/url/url_test.go
+++ b/libgo/go/net/url/url_test.go
@@ -618,7 +618,7 @@ var urltests = []URLTest{
// more useful string for debugging than fmt's struct printer
func ufmt(u *URL) string {
- var user, pass interface{}
+ var user, pass any
if u.User != nil {
user = u.User.Username()
if p, ok := u.User.Password(); ok {
@@ -1172,7 +1172,7 @@ var resolveReferenceTests = []struct {
{"http://foo.com/bar/baz", "quux/./dotdot/../dotdot/../dot/./tail/..", "http://foo.com/bar/quux/dot/"},
// Remove any dot-segments prior to forming the target URI.
- // http://tools.ietf.org/html/rfc3986#section-5.2.4
+ // https://datatracker.ietf.org/doc/html/rfc3986#section-5.2.4
{"http://foo.com/dot/./dotdot/../foo/bar", "../baz", "http://foo.com/dot/baz"},
// Triple dot isn't special
@@ -1192,7 +1192,7 @@ var resolveReferenceTests = []struct {
{"http://foo.com/foo%2dbar/", "./baz-quux", "http://foo.com/foo%2dbar/baz-quux"},
// RFC 3986: Normal Examples
- // http://tools.ietf.org/html/rfc3986#section-5.4.1
+ // https://datatracker.ietf.org/doc/html/rfc3986#section-5.4.1
{"http://a/b/c/d;p?q", "g:h", "g:h"},
{"http://a/b/c/d;p?q", "g", "http://a/b/c/g"},
{"http://a/b/c/d;p?q", "./g", "http://a/b/c/g"},
@@ -1218,7 +1218,7 @@ var resolveReferenceTests = []struct {
{"http://a/b/c/d;p?q", "../../g", "http://a/g"},
// RFC 3986: Abnormal Examples
- // http://tools.ietf.org/html/rfc3986#section-5.4.2
+ // https://datatracker.ietf.org/doc/html/rfc3986#section-5.4.2
{"http://a/b/c/d;p?q", "../../../g", "http://a/g"},
{"http://a/b/c/d;p?q", "../../../../g", "http://a/g"},
{"http://a/b/c/d;p?q", "/./g", "http://a/g"},
@@ -1244,6 +1244,9 @@ var resolveReferenceTests = []struct {
{"https://a/b/c/d;p?q", "//g/d/e/f?y#s", "https://g/d/e/f?y#s"},
{"https://a/b/c/d;p#s", "?y", "https://a/b/c/d;p?y"},
{"https://a/b/c/d;p?q#s", "?y", "https://a/b/c/d;p?y"},
+
+ // Empty path and query but with ForceQuery (issue 46033).
+ {"https://a/b/c/d;p?q#s", "?", "https://a/b/c/d;p?"},
}
func TestResolveReference(t *testing.T) {
@@ -2059,12 +2062,3 @@ func BenchmarkPathUnescape(b *testing.B) {
})
}
}
-
-var sink string
-
-func BenchmarkSplit(b *testing.B) {
- url := "http://www.google.com/?q=go+language#foo%26bar"
- for i := 0; i < b.N; i++ {
- sink, sink = split(url, '#', true)
- }
-}
diff --git a/libgo/go/net/write_unix_test.go b/libgo/go/net/write_unix_test.go
index f79f2d0..23e8bef 100644
--- a/libgo/go/net/write_unix_test.go
+++ b/libgo/go/net/write_unix_test.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris
-// +build darwin dragonfly freebsd linux netbsd openbsd solaris
package net
diff --git a/libgo/go/net/writev_test.go b/libgo/go/net/writev_test.go
index bf40ca2..18795a4 100644
--- a/libgo/go/net/writev_test.go
+++ b/libgo/go/net/writev_test.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build !js
-// +build !js
package net
@@ -187,10 +186,7 @@ func TestWritevError(t *testing.T) {
t.Skipf("skipping the test: windows does not have problem sending large chunks of data")
}
- ln, err := newLocalListener("tcp")
- if err != nil {
- t.Fatal(err)
- }
+ ln := newLocalListener(t, "tcp")
defer ln.Close()
ch := make(chan Conn, 1)
diff --git a/libgo/go/net/writev_unix.go b/libgo/go/net/writev_unix.go
index a0fedc2..51ab29d 100644
--- a/libgo/go/net/writev_unix.go
+++ b/libgo/go/net/writev_unix.go
@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file.
//go:build darwin || dragonfly || freebsd || illumos || linux || netbsd || openbsd
-// +build darwin dragonfly freebsd illumos linux netbsd openbsd
package net