// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

//go:build dragonfly || freebsd || linux || netbsd || openbsd || solaris

package x509

import (
	"bytes"
	"fmt"
	"os"
	"path/filepath"
	"reflect"
	"strings"
	"testing"
)

// Fixtures shared by the tests in this file. Relative paths are resolved
// against the package directory, which is where `go test` runs.

// Directory fixture: a directory of certificates whose certificate is
// expected to carry the common name testDirCN.
const (
	testDir   = "testdata"
	testDirCN = "test-dir"
)

// File fixture: a single certificate file expected to carry the common
// name testFileCN.
const (
	testFile   = "test-file.crt"
	testFileCN = "test-file"
)

// testMissing names a location that tests expect not to exist.
const testMissing = "missing"

// TestEnvVars checks how the certificate file and directory environment
// variables (certFileEnv and certDirEnv) interact with the default
// certFiles / certDirectories search lists used by loadSystemRoots:
// non-empty variables replace the defaults, empty ones fall through to them.
func TestEnvVars(t *testing.T) {
	testCases := []struct {
		name    string
		fileEnv string
		dirEnv  string
		files   []string
		dirs    []string
		cns     []string
	}{
		{
			// Environment variables override the default locations preventing fall through.
			name:    "override-defaults",
			fileEnv: testMissing,
			dirEnv:  testMissing,
			files:   []string{testFile},
			dirs:    []string{testDir},
			cns:     nil,
		},
		{
			// File environment overrides default file locations.
			name:    "file",
			fileEnv: testFile,
			dirEnv:  "",
			files:   nil,
			dirs:    nil,
			cns:     []string{testFileCN},
		},
		{
			// Directory environment overrides default directory locations.
			name:    "dir",
			fileEnv: "",
			dirEnv:  testDir,
			files:   nil,
			dirs:    nil,
			cns:     []string{testDirCN},
		},
		{
			// File & directory environment overrides both default locations.
			name:    "file+dir",
			fileEnv: testFile,
			dirEnv:  testDir,
			files:   nil,
			dirs:    nil,
			cns:     []string{testFileCN, testDirCN},
		},
		{
			// Environment variable empty / unset uses default locations.
			name:    "empty-fall-through",
			fileEnv: "",
			dirEnv:  "",
			files:   []string{testFile},
			dirs:    []string{testDir},
			cns:     []string{testFileCN, testDirCN},
		},
	}

	// certFiles and certDirectories are package-level state; restore them once
	// all subtests are done. The environment variables are restored per
	// subtest by t.Setenv, which also correctly re-unsets a variable that was
	// not set before the test.
	origCertFiles, origCertDirectories := certFiles, certDirectories
	defer func() {
		certFiles = origCertFiles
		certDirectories = origCertDirectories
	}()

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			t.Setenv(certFileEnv, tc.fileEnv)
			t.Setenv(certDirEnv, tc.dirEnv)

			certFiles, certDirectories = tc.files, tc.dirs

			r, err := loadSystemRoots()
			if err != nil {
				t.Fatal("unexpected failure:", err)
			}
			if r == nil {
				t.Fatal("nil roots")
			}

			// Verify that the returned certs match, otherwise report where the mismatch is.
			for i, cn := range tc.cns {
				if i >= r.len() {
					t.Errorf("missing cert %v @ %v", cn, i)
					continue
				}
				subject := r.mustCert(t, i).Subject
				if subject.CommonName != cn {
					// Log the full subject of the mismatching cert through the
					// test log rather than stdout, so it is tied to this subtest.
					t.Logf("cert %d subject: %#v", i, subject)
					t.Errorf("unexpected cert common name %q, want %q", subject.CommonName, cn)
				}
			}
			if r.len() > len(tc.cns) {
				t.Errorf("got %v certs, which is more than %v wanted", r.len(), len(tc.cns))
			}
		})
	}
}

// Ensure that "SSL_CERT_DIR" when used as the environment
// variable delimited by colons, allows loadSystemRoots to
// load all the roots from the respective directories.
// See https://golang.org/issue/35325.
func TestLoadSystemCertsLoadColonSeparatedDirs(t *testing.T) {
	origFile, origDir := os.Getenv(certFileEnv), os.Getenv(certDirEnv)
	origCertFiles := certFiles[:]

	// To prevent any other certs from being loaded in
	// through "SSL_CERT_FILE" or from known "certFiles",
	// clear them all, and they'll be reverting on defer.
	certFiles = certFiles[:0]
	os.Setenv(certFileEnv, "")

	defer func() {
		certFiles = origCertFiles[:]
		os.Setenv(certDirEnv, origDir)
		os.Setenv(certFileEnv, origFile)
	}()

	tmpDir := t.TempDir()

	rootPEMs := []string{
		geoTrustRoot,
		googleLeaf,
		startComRoot,
	}

	var certDirs []string
	for i, certPEM := range rootPEMs {
		certDir := filepath.Join(tmpDir, fmt.Sprintf("cert-%d", i))
		if err := os.MkdirAll(certDir, 0755); err != nil {
			t.Fatalf("Failed to create certificate dir: %v", err)
		}
		certOutFile := filepath.Join(certDir, "cert.crt")
		if err := os.WriteFile(certOutFile, []byte(certPEM), 0655); err != nil {
			t.Fatalf("Failed to write certificate to file: %v", err)
		}
		certDirs = append(certDirs, certDir)
	}

	// Sanity check: the number of certDirs should be equal to the number of roots.
	if g, w := len(certDirs), len(rootPEMs); g != w {
		t.Fatalf("Failed sanity check: len(certsDir)=%d is not equal to len(rootsPEMS)=%d", g, w)
	}

	// Now finally concatenate them with a colon.
	colonConcatCertDirs := strings.Join(certDirs, ":")
	os.Setenv(certDirEnv, colonConcatCertDirs)
	gotPool, err := loadSystemRoots()
	if err != nil {
		t.Fatalf("Failed to load system roots: %v", err)
	}
	subjects := gotPool.Subjects()
	// We expect exactly len(rootPEMs) subjects back.
	if g, w := len(subjects), len(rootPEMs); g != w {
		t.Fatalf("Invalid number of subjects: got %d want %d", g, w)
	}

	wantPool := NewCertPool()
	for _, certPEM := range rootPEMs {
		wantPool.AppendCertsFromPEM([]byte(certPEM))
	}
	strCertPool := func(p *CertPool) string {
		return string(bytes.Join(p.Subjects(), []byte("\n")))
	}

	if !certPoolEqual(gotPool, wantPool) {
		g, w := strCertPool(gotPool), strCertPool(wantPool)
		t.Fatalf("Mismatched certPools\nGot:\n%s\n\nWant:\n%s", g, w)
	}
}

// TestReadUniqueDirectoryEntries checks which entries
// readUniqueDirectoryEntries reports for a directory holding a regular file,
// a symlink to a name inside the same directory ("link-in"), and a symlink
// that points outside it ("link-out"). Only "file" and "link-out" are
// expected back.
func TestReadUniqueDirectoryEntries(t *testing.T) {
	tmp := t.TempDir()
	temp := func(base string) string { return filepath.Join(tmp, base) }

	f, err := os.Create(temp("file"))
	if err != nil {
		t.Fatal(err)
	}
	if err := f.Close(); err != nil {
		t.Fatal(err)
	}
	// Neither link target is created by this test; only where each link
	// points matters.
	if err := os.Symlink("target-in", temp("link-in")); err != nil {
		t.Fatal(err)
	}
	if err := os.Symlink("../target-out", temp("link-out")); err != nil {
		t.Fatal(err)
	}

	got, err := readUniqueDirectoryEntries(tmp)
	if err != nil {
		t.Fatal(err)
	}
	gotNames := make([]string, 0, len(got))
	for _, fi := range got {
		gotNames = append(gotNames, fi.Name())
	}
	wantNames := []string{"file", "link-out"}
	if !reflect.DeepEqual(gotNames, wantNames) {
		t.Errorf("got %q; want %q", gotNames, wantNames)
	}
}