diff options
Diffstat (limited to 'libgo/go/crypto/tls/tls_test.go')
-rw-r--r-- | libgo/go/crypto/tls/tls_test.go | 174 |
1 files changed, 155 insertions, 19 deletions
diff --git a/libgo/go/crypto/tls/tls_test.go b/libgo/go/crypto/tls/tls_test.go index 178b519..1984234 100644 --- a/libgo/go/crypto/tls/tls_test.go +++ b/libgo/go/crypto/tls/tls_test.go @@ -6,6 +6,7 @@ package tls import ( "bytes" + "context" "crypto" "crypto/x509" "encoding/json" @@ -201,6 +202,118 @@ func TestDialTimeout(t *testing.T) { } } +func TestDeadlineOnWrite(t *testing.T) { + if testing.Short() { + t.Skip("skipping in short mode") + } + + ln := newLocalListener(t) + defer ln.Close() + + srvCh := make(chan *Conn, 1) + + go func() { + sconn, err := ln.Accept() + if err != nil { + srvCh <- nil + return + } + srv := Server(sconn, testConfig.Clone()) + if err := srv.Handshake(); err != nil { + srvCh <- nil + return + } + srvCh <- srv + }() + + clientConfig := testConfig.Clone() + clientConfig.MaxVersion = VersionTLS12 + conn, err := Dial("tcp", ln.Addr().String(), clientConfig) + if err != nil { + t.Fatal(err) + } + defer conn.Close() + + srv := <-srvCh + if srv == nil { + t.Error(err) + } + + // Make sure the client/server is setup correctly and is able to do a typical Write/Read + buf := make([]byte, 6) + if _, err := srv.Write([]byte("foobar")); err != nil { + t.Errorf("Write err: %v", err) + } + if n, err := conn.Read(buf); n != 6 || err != nil || string(buf) != "foobar" { + t.Errorf("Read = %d, %v, data %q; want 6, nil, foobar", n, err, buf) + } + + // Set a deadline which should cause Write to timeout + if err = srv.SetDeadline(time.Now()); err != nil { + t.Fatalf("SetDeadline(time.Now()) err: %v", err) + } + if _, err = srv.Write([]byte("should fail")); err == nil { + t.Fatal("Write should have timed out") + } + + // Clear deadline and make sure it still times out + if err = srv.SetDeadline(time.Time{}); err != nil { + t.Fatalf("SetDeadline(time.Time{}) err: %v", err) + } + if _, err = srv.Write([]byte("This connection is permanently broken")); err == nil { + t.Fatal("Write which previously failed should still time out") + } + + // Verify the error + if ne := err.(net.Error); ne.Temporary() != false { + t.Error("Write timed out but incorrectly classified the error as Temporary") + } + if !isTimeoutError(err) { + t.Error("Write timed out but did not classify the error as a Timeout") + } +} + +type readerFunc func([]byte) (int, error) + +func (f readerFunc) Read(b []byte) (int, error) { return f(b) } + +// TestDialer tests that tls.Dialer.DialContext can abort in the middle of a handshake. +// (The other cases are all handled by the existing dial tests in this package, which +// all also flow through the same code shared code paths) +func TestDialer(t *testing.T) { + ln := newLocalListener(t) + defer ln.Close() + + unblockServer := make(chan struct{}) // close-only + defer close(unblockServer) + go func() { + conn, err := ln.Accept() + if err != nil { + return + } + defer conn.Close() + <-unblockServer + }() + + ctx, cancel := context.WithCancel(context.Background()) + d := Dialer{Config: &Config{ + Rand: readerFunc(func(b []byte) (n int, err error) { + // By the time crypto/tls wants randomness, that means it has a TCP + // connection, so we're past the Dialer's dial and now blocked + // in a handshake. Cancel our context and see if we get unstuck. + // (Our TCP listener above never reads or writes, so the Handshake + // would otherwise be stuck forever) + cancel() + return len(b), nil + }), + ServerName: "foo", + }} + _, err := d.DialContext(ctx, "tcp", ln.Addr().String()) + if err != context.Canceled { + t.Errorf("err = %v; want context.Canceled", err) + } +} + func isTimeoutError(err error) bool { if ne, ok := err.(net.Error); ok { return ne.Timeout() @@ -294,7 +407,11 @@ func TestTLSUniqueMatches(t *testing.T) { defer ln.Close() serverTLSUniques := make(chan []byte) + parentDone := make(chan struct{}) + childDone := make(chan struct{}) + defer close(parentDone) go func() { + defer close(childDone) for i := 0; i < 2; i++ { sconn, err := ln.Accept() if err != nil { @@ -308,7 +425,11 @@ func TestTLSUniqueMatches(t *testing.T) { t.Error(err) return } - serverTLSUniques <- srv.ConnectionState().TLSUnique + select { + case <-parentDone: + return + case serverTLSUniques <- srv.ConnectionState().TLSUnique: + } } }() @@ -318,7 +439,15 @@ func TestTLSUniqueMatches(t *testing.T) { if err != nil { t.Fatal(err) } - if !bytes.Equal(conn.ConnectionState().TLSUnique, <-serverTLSUniques) { + + var serverTLSUniquesValue []byte + select { + case <-childDone: + return + case serverTLSUniquesValue = <-serverTLSUniques: + } + + if !bytes.Equal(conn.ConnectionState().TLSUnique, serverTLSUniquesValue) { t.Error("client and server channel bindings differ") } conn.Close() @@ -331,7 +460,14 @@ func TestTLSUniqueMatches(t *testing.T) { if !conn.ConnectionState().DidResume { t.Error("second session did not use resumption") } - if !bytes.Equal(conn.ConnectionState().TLSUnique, <-serverTLSUniques) { + + select { + case <-childDone: + return + case serverTLSUniquesValue = <-serverTLSUniques: + } + + if !bytes.Equal(conn.ConnectionState().TLSUnique, serverTLSUniquesValue) { t.Error("client and server channel bindings differ when session resumption is used") } } @@ -598,7 +734,7 @@ func TestWarningAlertFlood(t *testing.T) { } func TestCloneFuncFields(t *testing.T) { - const expectedCount = 5 + const expectedCount = 6 called := 0 c1 := Config{ @@ -622,6 +758,10 @@ func TestCloneFuncFields(t *testing.T) { called |= 1 << 4 return nil }, + VerifyConnection: func(ConnectionState) error { + called |= 1 << 5 + return nil + }, } c2 := c1.Clone() @@ -631,6 +771,7 @@ func TestCloneFuncFields(t *testing.T) { c2.GetClientCertificate(nil) c2.GetConfigForClient(nil) c2.VerifyPeerCertificate(nil, nil) + c2.VerifyConnection(ConnectionState{}) if called != (1<<expectedCount)-1 { t.Fatalf("expected %d calls but saw calls %b", expectedCount, called) @@ -644,17 +785,12 @@ func TestCloneNonFuncFields(t *testing.T) { typ := v.Type() for i := 0; i < typ.NumField(); i++ { f := v.Field(i) - if !f.CanSet() { - // unexported field; not cloned. - continue - } - // testing/quick can't handle functions or interfaces and so // isn't used here. switch fn := typ.Field(i).Name; fn { case "Rand": f.Set(reflect.ValueOf(io.Reader(os.Stdin))) - case "Time", "GetCertificate", "GetConfigForClient", "VerifyPeerCertificate", "GetClientCertificate": + case "Time", "GetCertificate", "GetConfigForClient", "VerifyPeerCertificate", "VerifyConnection", "GetClientCertificate": // DeepEqual can't compare functions. If you add a // function field to this list, you must also change // TestCloneFuncFields to ensure that the func field is @@ -689,17 +825,17 @@ func TestCloneNonFuncFields(t *testing.T) { f.Set(reflect.ValueOf([]CurveID{CurveP256})) case "Renegotiation": f.Set(reflect.ValueOf(RenegotiateOnceAsClient)) + case "mutex", "autoSessionTicketKeys", "sessionTicketKeys": + continue // these are unexported fields that are handled separately default: t.Errorf("all fields must be accounted for, but saw unknown field %q", fn) } } + // Set the unexported fields related to session ticket keys, which are copied with Clone(). + c1.autoSessionTicketKeys = []ticketKey{c1.ticketKeyFromBytes(c1.SessionTicketKey)} + c1.sessionTicketKeys = []ticketKey{c1.ticketKeyFromBytes(c1.SessionTicketKey)} c2 := c1.Clone() - // DeepEqual also compares unexported fields, thus c2 needs to have run - // serverInit in order to be DeepEqual to c1. Cloning it and discarding - // the result is sufficient. - c2.Clone() - if !reflect.DeepEqual(&c1, c2) { t.Errorf("clone failed to copy a field") } @@ -980,8 +1116,8 @@ func TestConnectionState(t *testing.T) { if ss.ServerName != serverName { t.Errorf("Got server name %q, expected %q", ss.ServerName, serverName) } - if cs.ServerName != "" { - t.Errorf("Got unexpected server name on the client side") + if cs.ServerName != serverName { + t.Errorf("Got server name on client connection %q, expected %q", cs.ServerName, serverName) } if len(ss.PeerCertificates) != 1 || len(cs.PeerCertificates) != 1 { @@ -1307,7 +1443,7 @@ func (s brokenSigner) Sign(rand io.Reader, digest []byte, opts crypto.SignerOpts } // TestPKCS1OnlyCert uses a client certificate with a broken crypto.Signer that -// always makes PKCS#1 v1.5 signatures, so can't be used with RSA-PSS. +// always makes PKCS #1 v1.5 signatures, so can't be used with RSA-PSS. func TestPKCS1OnlyCert(t *testing.T) { clientConfig := testConfig.Clone() clientConfig.Certificates = []Certificate{{ @@ -1315,7 +1451,7 @@ func TestPKCS1OnlyCert(t *testing.T) { PrivateKey: brokenSigner{testRSAPrivateKey}, }} serverConfig := testConfig.Clone() - serverConfig.MaxVersion = VersionTLS12 // TLS 1.3 doesn't support PKCS#1 v1.5 + serverConfig.MaxVersion = VersionTLS12 // TLS 1.3 doesn't support PKCS #1 v1.5 serverConfig.ClientAuth = RequireAnyClientCert // If RSA-PSS is selected, the handshake should fail. |