aboutsummaryrefslogtreecommitdiff
path: root/libgo/go/crypto/tls/tls_test.go
diff options
context:
space:
mode:
Diffstat (limited to 'libgo/go/crypto/tls/tls_test.go')
-rw-r--r--libgo/go/crypto/tls/tls_test.go174
1 files changed, 155 insertions, 19 deletions
diff --git a/libgo/go/crypto/tls/tls_test.go b/libgo/go/crypto/tls/tls_test.go
index 178b519..1984234 100644
--- a/libgo/go/crypto/tls/tls_test.go
+++ b/libgo/go/crypto/tls/tls_test.go
@@ -6,6 +6,7 @@ package tls
import (
"bytes"
+ "context"
"crypto"
"crypto/x509"
"encoding/json"
@@ -201,6 +202,118 @@ func TestDialTimeout(t *testing.T) {
}
}
+func TestDeadlineOnWrite(t *testing.T) {
+ if testing.Short() {
+ t.Skip("skipping in short mode")
+ }
+
+ ln := newLocalListener(t)
+ defer ln.Close()
+
+ srvCh := make(chan *Conn, 1)
+
+ go func() {
+ sconn, err := ln.Accept()
+ if err != nil {
+ srvCh <- nil
+ return
+ }
+ srv := Server(sconn, testConfig.Clone())
+ if err := srv.Handshake(); err != nil {
+ srvCh <- nil
+ return
+ }
+ srvCh <- srv
+ }()
+
+ clientConfig := testConfig.Clone()
+ clientConfig.MaxVersion = VersionTLS12
+ conn, err := Dial("tcp", ln.Addr().String(), clientConfig)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer conn.Close()
+
+ srv := <-srvCh
+ if srv == nil {
+ t.Error(err)
+ }
+
+ // Make sure the client/server is setup correctly and is able to do a typical Write/Read
+ buf := make([]byte, 6)
+ if _, err := srv.Write([]byte("foobar")); err != nil {
+ t.Errorf("Write err: %v", err)
+ }
+ if n, err := conn.Read(buf); n != 6 || err != nil || string(buf) != "foobar" {
+ t.Errorf("Read = %d, %v, data %q; want 6, nil, foobar", n, err, buf)
+ }
+
+ // Set a deadline which should cause Write to timeout
+ if err = srv.SetDeadline(time.Now()); err != nil {
+ t.Fatalf("SetDeadline(time.Now()) err: %v", err)
+ }
+ if _, err = srv.Write([]byte("should fail")); err == nil {
+ t.Fatal("Write should have timed out")
+ }
+
+ // Clear deadline and make sure it still times out
+ if err = srv.SetDeadline(time.Time{}); err != nil {
+ t.Fatalf("SetDeadline(time.Time{}) err: %v", err)
+ }
+ if _, err = srv.Write([]byte("This connection is permanently broken")); err == nil {
+ t.Fatal("Write which previously failed should still time out")
+ }
+
+ // Verify the error
+ if ne := err.(net.Error); ne.Temporary() != false {
+ t.Error("Write timed out but incorrectly classified the error as Temporary")
+ }
+ if !isTimeoutError(err) {
+ t.Error("Write timed out but did not classify the error as a Timeout")
+ }
+}
+
+type readerFunc func([]byte) (int, error)
+
+func (f readerFunc) Read(b []byte) (int, error) { return f(b) }
+
+// TestDialer tests that tls.Dialer.DialContext can abort in the middle of a handshake.
+// (The other cases are all handled by the existing dial tests in this package, which
+// all also flow through the same code shared code paths)
+func TestDialer(t *testing.T) {
+ ln := newLocalListener(t)
+ defer ln.Close()
+
+ unblockServer := make(chan struct{}) // close-only
+ defer close(unblockServer)
+ go func() {
+ conn, err := ln.Accept()
+ if err != nil {
+ return
+ }
+ defer conn.Close()
+ <-unblockServer
+ }()
+
+ ctx, cancel := context.WithCancel(context.Background())
+ d := Dialer{Config: &Config{
+ Rand: readerFunc(func(b []byte) (n int, err error) {
+ // By the time crypto/tls wants randomness, that means it has a TCP
+ // connection, so we're past the Dialer's dial and now blocked
+ // in a handshake. Cancel our context and see if we get unstuck.
+ // (Our TCP listener above never reads or writes, so the Handshake
+ // would otherwise be stuck forever)
+ cancel()
+ return len(b), nil
+ }),
+ ServerName: "foo",
+ }}
+ _, err := d.DialContext(ctx, "tcp", ln.Addr().String())
+ if err != context.Canceled {
+ t.Errorf("err = %v; want context.Canceled", err)
+ }
+}
+
func isTimeoutError(err error) bool {
if ne, ok := err.(net.Error); ok {
return ne.Timeout()
@@ -294,7 +407,11 @@ func TestTLSUniqueMatches(t *testing.T) {
defer ln.Close()
serverTLSUniques := make(chan []byte)
+ parentDone := make(chan struct{})
+ childDone := make(chan struct{})
+ defer close(parentDone)
go func() {
+ defer close(childDone)
for i := 0; i < 2; i++ {
sconn, err := ln.Accept()
if err != nil {
@@ -308,7 +425,11 @@ func TestTLSUniqueMatches(t *testing.T) {
t.Error(err)
return
}
- serverTLSUniques <- srv.ConnectionState().TLSUnique
+ select {
+ case <-parentDone:
+ return
+ case serverTLSUniques <- srv.ConnectionState().TLSUnique:
+ }
}
}()
@@ -318,7 +439,15 @@ func TestTLSUniqueMatches(t *testing.T) {
if err != nil {
t.Fatal(err)
}
- if !bytes.Equal(conn.ConnectionState().TLSUnique, <-serverTLSUniques) {
+
+ var serverTLSUniquesValue []byte
+ select {
+ case <-childDone:
+ return
+ case serverTLSUniquesValue = <-serverTLSUniques:
+ }
+
+ if !bytes.Equal(conn.ConnectionState().TLSUnique, serverTLSUniquesValue) {
t.Error("client and server channel bindings differ")
}
conn.Close()
@@ -331,7 +460,14 @@ func TestTLSUniqueMatches(t *testing.T) {
if !conn.ConnectionState().DidResume {
t.Error("second session did not use resumption")
}
- if !bytes.Equal(conn.ConnectionState().TLSUnique, <-serverTLSUniques) {
+
+ select {
+ case <-childDone:
+ return
+ case serverTLSUniquesValue = <-serverTLSUniques:
+ }
+
+ if !bytes.Equal(conn.ConnectionState().TLSUnique, serverTLSUniquesValue) {
t.Error("client and server channel bindings differ when session resumption is used")
}
}
@@ -598,7 +734,7 @@ func TestWarningAlertFlood(t *testing.T) {
}
func TestCloneFuncFields(t *testing.T) {
- const expectedCount = 5
+ const expectedCount = 6
called := 0
c1 := Config{
@@ -622,6 +758,10 @@ func TestCloneFuncFields(t *testing.T) {
called |= 1 << 4
return nil
},
+ VerifyConnection: func(ConnectionState) error {
+ called |= 1 << 5
+ return nil
+ },
}
c2 := c1.Clone()
@@ -631,6 +771,7 @@ func TestCloneFuncFields(t *testing.T) {
c2.GetClientCertificate(nil)
c2.GetConfigForClient(nil)
c2.VerifyPeerCertificate(nil, nil)
+ c2.VerifyConnection(ConnectionState{})
if called != (1<<expectedCount)-1 {
t.Fatalf("expected %d calls but saw calls %b", expectedCount, called)
@@ -644,17 +785,12 @@ func TestCloneNonFuncFields(t *testing.T) {
typ := v.Type()
for i := 0; i < typ.NumField(); i++ {
f := v.Field(i)
- if !f.CanSet() {
- // unexported field; not cloned.
- continue
- }
-
// testing/quick can't handle functions or interfaces and so
// isn't used here.
switch fn := typ.Field(i).Name; fn {
case "Rand":
f.Set(reflect.ValueOf(io.Reader(os.Stdin)))
- case "Time", "GetCertificate", "GetConfigForClient", "VerifyPeerCertificate", "GetClientCertificate":
+ case "Time", "GetCertificate", "GetConfigForClient", "VerifyPeerCertificate", "VerifyConnection", "GetClientCertificate":
// DeepEqual can't compare functions. If you add a
// function field to this list, you must also change
// TestCloneFuncFields to ensure that the func field is
@@ -689,17 +825,17 @@ func TestCloneNonFuncFields(t *testing.T) {
f.Set(reflect.ValueOf([]CurveID{CurveP256}))
case "Renegotiation":
f.Set(reflect.ValueOf(RenegotiateOnceAsClient))
+ case "mutex", "autoSessionTicketKeys", "sessionTicketKeys":
+ continue // these are unexported fields that are handled separately
default:
t.Errorf("all fields must be accounted for, but saw unknown field %q", fn)
}
}
+ // Set the unexported fields related to session ticket keys, which are copied with Clone().
+ c1.autoSessionTicketKeys = []ticketKey{c1.ticketKeyFromBytes(c1.SessionTicketKey)}
+ c1.sessionTicketKeys = []ticketKey{c1.ticketKeyFromBytes(c1.SessionTicketKey)}
c2 := c1.Clone()
- // DeepEqual also compares unexported fields, thus c2 needs to have run
- // serverInit in order to be DeepEqual to c1. Cloning it and discarding
- // the result is sufficient.
- c2.Clone()
-
if !reflect.DeepEqual(&c1, c2) {
t.Errorf("clone failed to copy a field")
}
@@ -980,8 +1116,8 @@ func TestConnectionState(t *testing.T) {
if ss.ServerName != serverName {
t.Errorf("Got server name %q, expected %q", ss.ServerName, serverName)
}
- if cs.ServerName != "" {
- t.Errorf("Got unexpected server name on the client side")
+ if cs.ServerName != serverName {
+ t.Errorf("Got server name on client connection %q, expected %q", cs.ServerName, serverName)
}
if len(ss.PeerCertificates) != 1 || len(cs.PeerCertificates) != 1 {
@@ -1307,7 +1443,7 @@ func (s brokenSigner) Sign(rand io.Reader, digest []byte, opts crypto.SignerOpts
}
// TestPKCS1OnlyCert uses a client certificate with a broken crypto.Signer that
-// always makes PKCS#1 v1.5 signatures, so can't be used with RSA-PSS.
+// always makes PKCS #1 v1.5 signatures, so can't be used with RSA-PSS.
func TestPKCS1OnlyCert(t *testing.T) {
clientConfig := testConfig.Clone()
clientConfig.Certificates = []Certificate{{
@@ -1315,7 +1451,7 @@ func TestPKCS1OnlyCert(t *testing.T) {
PrivateKey: brokenSigner{testRSAPrivateKey},
}}
serverConfig := testConfig.Clone()
- serverConfig.MaxVersion = VersionTLS12 // TLS 1.3 doesn't support PKCS#1 v1.5
+ serverConfig.MaxVersion = VersionTLS12 // TLS 1.3 doesn't support PKCS #1 v1.5
serverConfig.ClientAuth = RequireAnyClientCert
// If RSA-PSS is selected, the handshake should fail.