Explorar o código

refactor: simplify test asserts (#1271)

Oleksandr Redko %!s(int64=2) %!d(string=hai) anos
pai
achega
2959fc072c

+ 4 - 13
config/config_test.go

@@ -7,6 +7,7 @@ import (
 
 	"github.com/spf13/viper"
 	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
 )
 
 const configPath = "../testdata/config/"
@@ -126,23 +127,13 @@ func TestTranslate(t *testing.T) {
 		viper.SetConfigName(tt.cfgName)
 		viper.SetConfigType("toml")
 		err := viper.ReadInConfig()
-		if err != nil {
-			t.Error(err)
-		}
+		require.NoError(t, err)
 
 		var vc ViperConfig
 		err = viper.Unmarshal(&vc)
-		if err != nil {
-			t.Error(err)
-		}
+		require.NoError(t, err)
 		cfg, err := vc.Translate()
-		if tt.wantError != nil {
-			if err == nil {
-				t.Errorf("expected error")
-			}
-			assert.Equal(t, tt.wantError, err)
-		}
-
+		assert.Equal(t, tt.wantError, err)
 		assert.Equal(t, cfg.Rules, tt.cfg.Rules)
 	}
 }

+ 3 - 2
detect/baseline_test.go

@@ -5,6 +5,7 @@ import (
 	"testing"
 
 	"github.com/stretchr/testify/assert"
+
 	"github.com/zricethezav/gitleaks/v8/report"
 )
 
@@ -82,7 +83,7 @@ func TestFileLoadBaseline(t *testing.T) {
 
 	for _, test := range tests {
 		_, err := LoadBaseline(test.Filename)
-		assert.Equal(t, test.ExpectedError.Error(), err.Error())
+		assert.Equal(t, test.ExpectedError, err)
 	}
 }
 
@@ -132,6 +133,6 @@ func TestIgnoreIssuesInBaseline(t *testing.T) {
 		for _, finding := range test.findings {
 			d.addFinding(finding)
 		}
-		assert.Equal(t, test.expectCount, len(d.findings))
+		assert.Len(t, d.findings, test.expectCount)
 	}
 }

+ 37 - 98
detect/detect_test.go

@@ -8,6 +8,7 @@ import (
 
 	"github.com/spf13/viper"
 	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
 
 	"github.com/zricethezav/gitleaks/v8/config"
 	"github.com/zricethezav/gitleaks/v8/report"
@@ -336,23 +337,14 @@ func TestDetect(t *testing.T) {
 		viper.SetConfigName(tt.cfgName)
 		viper.SetConfigType("toml")
 		err := viper.ReadInConfig()
-		if err != nil {
-			t.Error(err)
-		}
+		require.NoError(t, err)
 
 		var vc config.ViperConfig
 		err = viper.Unmarshal(&vc)
-		if err != nil {
-			t.Error(err)
-		}
+		require.NoError(t, err)
 		cfg, err := vc.Translate()
 		cfg.Path = filepath.Join(configPath, tt.cfgName+".toml")
-		if tt.wantError != nil {
-			if err == nil {
-				t.Errorf("expected error")
-			}
-			assert.Equal(t, tt.wantError, err)
-		}
+		assert.Equal(t, tt.wantError, err)
 		d := NewDetector(cfg)
 		d.baselinePath = tt.baselinePath
 
@@ -444,56 +436,38 @@ func TestFromGit(t *testing.T) {
 		},
 	}
 
-	err := moveDotGit("dotGit", ".git")
-	if err != nil {
-		t.Fatal(err)
-	}
-	defer func() {
-		if err := moveDotGit(".git", "dotGit"); err != nil {
-			t.Error(err)
-		}
-	}()
+	moveDotGit(t, "dotGit", ".git")
+	defer moveDotGit(t, ".git", "dotGit")
 
 	for _, tt := range tests {
 
 		viper.AddConfigPath(configPath)
 		viper.SetConfigName("simple")
 		viper.SetConfigType("toml")
-		err = viper.ReadInConfig()
-		if err != nil {
-			t.Error(err)
-		}
+		err := viper.ReadInConfig()
+		require.NoError(t, err)
 
 		var vc config.ViperConfig
 		err = viper.Unmarshal(&vc)
-		if err != nil {
-			t.Error(err)
-		}
+		require.NoError(t, err)
 		cfg, err := vc.Translate()
-		if err != nil {
-			t.Error(err)
-		}
+		require.NoError(t, err)
 		detector := NewDetector(cfg)
 
 		var ignorePath string
 		info, err := os.Stat(tt.source)
-		if err != nil {
-			t.Fatalf("could not os.Stat: %v", err)
-		}
+		require.NoError(t, err)
 
 		if info.IsDir() {
 			ignorePath = filepath.Join(tt.source, ".gitleaksignore")
 		} else {
 			ignorePath = filepath.Join(filepath.Dir(tt.source), ".gitleaksignore")
 		}
-		if err = detector.AddGitleaksIgnore(ignorePath); err != nil {
-			t.Fatalf("could not call AddGitleaksIgnore: %v", err)
-		}
+		err = detector.AddGitleaksIgnore(ignorePath)
+		require.NoError(t, err)
 
 		findings, err := detector.DetectGit(tt.source, tt.logOpts, DetectType)
-		if err != nil {
-			t.Error(err)
-		}
+		require.NoError(t, err)
 
 		for _, f := range findings {
 			f.Match = "" // remove lines cause copying and pasting them has some wack formatting
@@ -540,43 +514,27 @@ func TestFromGitStaged(t *testing.T) {
 		},
 	}
 
-	err := moveDotGit("dotGit", ".git")
-	if err != nil {
-		t.Fatal(err)
-	}
-	defer func() {
-		if err := moveDotGit(".git", "dotGit"); err != nil {
-			t.Error(err)
-		}
-	}()
+	moveDotGit(t, "dotGit", ".git")
+	defer moveDotGit(t, ".git", "dotGit")
 
 	for _, tt := range tests {
 
 		viper.AddConfigPath(configPath)
 		viper.SetConfigName("simple")
 		viper.SetConfigType("toml")
-		err = viper.ReadInConfig()
-		if err != nil {
-			t.Error(err)
-		}
+		err := viper.ReadInConfig()
+		require.NoError(t, err)
 
 		var vc config.ViperConfig
 		err = viper.Unmarshal(&vc)
-		if err != nil {
-			t.Error(err)
-		}
+		require.NoError(t, err)
 		cfg, err := vc.Translate()
-		if err != nil {
-			t.Error(err)
-		}
+		require.NoError(t, err)
 		detector := NewDetector(cfg)
-		if err = detector.AddGitleaksIgnore(filepath.Join(tt.source, ".gitleaksignore")); err != nil {
-			t.Fatalf("could not call AddGitleaksIgnore: %v", err)
-		}
+		err = detector.AddGitleaksIgnore(filepath.Join(tt.source, ".gitleaksignore"))
+		require.NoError(t, err)
 		findings, err := detector.DetectGit(tt.source, tt.logOpts, ProtectStagedType)
-		if err != nil {
-			t.Error(err)
-		}
+		require.NoError(t, err)
 
 		for _, f := range findings {
 			f.Match = "" // remove lines cause copying and pasting them has some wack formatting
@@ -647,38 +605,28 @@ func TestFromFiles(t *testing.T) {
 		viper.SetConfigName("simple")
 		viper.SetConfigType("toml")
 		err := viper.ReadInConfig()
-		if err != nil {
-			t.Error(err)
-		}
+		require.NoError(t, err)
 
 		var vc config.ViperConfig
 		err = viper.Unmarshal(&vc)
-		if err != nil {
-			t.Error(err)
-		}
+		require.NoError(t, err)
 		cfg, _ := vc.Translate()
 		detector := NewDetector(cfg)
 
 		var ignorePath string
 		info, err := os.Stat(tt.source)
-		if err != nil {
-			t.Fatalf("could not call os.Stat: %v", err)
-		}
+		require.NoError(t, err)
 
 		if info.IsDir() {
 			ignorePath = filepath.Join(tt.source, ".gitleaksignore")
 		} else {
 			ignorePath = filepath.Join(filepath.Dir(tt.source), ".gitleaksignore")
 		}
-		if err = detector.AddGitleaksIgnore(ignorePath); err != nil {
-			t.Fatalf("could not call AddGitleaksIgnore: %v", err)
-		}
+		err = detector.AddGitleaksIgnore(ignorePath)
+		require.NoError(t, err)
 		detector.FollowSymlinks = true
 		findings, err := detector.DetectFiles(tt.source)
-		if err != nil {
-			t.Error(err)
-		}
-
+		require.NoError(t, err)
 		assert.ElementsMatch(t, tt.expectedFindings, findings)
 	}
 }
@@ -718,31 +666,25 @@ func TestDetectWithSymlinks(t *testing.T) {
 		viper.SetConfigName("simple")
 		viper.SetConfigType("toml")
 		err := viper.ReadInConfig()
-		if err != nil {
-			t.Error(err)
-		}
+		require.NoError(t, err)
 
 		var vc config.ViperConfig
 		err = viper.Unmarshal(&vc)
-		if err != nil {
-			t.Error(err)
-		}
+		require.NoError(t, err)
 		cfg, _ := vc.Translate()
 		detector := NewDetector(cfg)
 		detector.FollowSymlinks = true
 		findings, err := detector.DetectFiles(tt.source)
-		if err != nil {
-			t.Error(err)
-		}
+		require.NoError(t, err)
 		assert.ElementsMatch(t, tt.expectedFindings, findings)
 	}
 }
 
-func moveDotGit(from, to string) error {
+func moveDotGit(t *testing.T, from, to string) {
+	t.Helper()
+
 	repoDirs, err := os.ReadDir("../testdata/repos")
-	if err != nil {
-		return err
-	}
+	require.NoError(t, err)
 	for _, dir := range repoDirs {
 		if to == ".git" {
 			_, err := os.Stat(fmt.Sprintf("%s/%s/%s", repoBasePath, dir.Name(), "dotGit"))
@@ -762,9 +704,6 @@ func moveDotGit(from, to string) error {
 
 		err = os.Rename(fmt.Sprintf("%s/%s/%s", repoBasePath, dir.Name(), from),
 			fmt.Sprintf("%s/%s/%s", repoBasePath, dir.Name(), to))
-		if err != nil {
-			return err
-		}
+		require.NoError(t, err)
 	}
-	return nil
 }

+ 3 - 6
detect/location_test.go

@@ -2,6 +2,8 @@ package detect
 
 import (
 	"testing"
+
+	"github.com/stretchr/testify/assert"
 )
 
 // TestGetLocation tests the getLocation function.
@@ -50,11 +52,6 @@ func TestGetLocation(t *testing.T) {
 
 	for _, test := range tests {
 		loc := location(Fragment{newlineIndices: test.linePairs}, []int{test.start, test.end})
-		if loc != test.wantLocation {
-			t.Errorf("\nstartLine %d\nstartColumn: %d\nendLine: %d\nendColumn: %d\nstartLineIndex: %d\nendlineIndex %d",
-				loc.startLine, loc.startColumn, loc.endLine, loc.endColumn, loc.startLineIndex, loc.endLineIndex)
-
-			t.Error("got", loc, "want", test.wantLocation)
-		}
+		assert.Equal(t, test.wantLocation, loc)
 	}
 }

+ 20 - 32
report/csv_test.go

@@ -1,11 +1,12 @@
 package report
 
 import (
-	"bytes"
 	"os"
 	"path/filepath"
-	"strings"
 	"testing"
+
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
 )
 
 func TestWriteCSV(t *testing.T) {
@@ -43,39 +44,26 @@ func TestWriteCSV(t *testing.T) {
 			wantEmpty:      true,
 			testReportName: "empty",
 			expected:       filepath.Join(expectPath, "report", "this_should_not_exist.csv"),
-			findings:       []Finding{}},
+			findings:       []Finding{},
+		},
 	}
 
 	for _, test := range tests {
-		tmpfile, err := os.Create(filepath.Join(t.TempDir(), test.testReportName+".csv"))
-		if err != nil {
-			t.Error(err)
-		}
-		err = writeCsv(test.findings, tmpfile)
-		if err != nil {
-			t.Error(err)
-		}
-		got, err := os.ReadFile(tmpfile.Name())
-		if err != nil {
-			t.Error(err)
-		}
-		if test.wantEmpty {
-			if len(got) > 0 {
-				t.Errorf("Expected empty file, got %s", got)
-			}
-			continue
-		}
-		want, err := os.ReadFile(test.expected)
-		if err != nil {
-			t.Error(err)
-		}
-
-		if !bytes.Equal(got, want) {
-			err = os.WriteFile(strings.Replace(test.expected, ".csv", ".got.csv", 1), got, 0644)
-			if err != nil {
-				t.Error(err)
+		t.Run(test.testReportName, func(t *testing.T) {
+			tmpfile, err := os.Create(filepath.Join(t.TempDir(), test.testReportName+".csv"))
+			require.NoError(t, err)
+			err = writeCsv(test.findings, tmpfile)
+			require.NoError(t, err)
+			assert.FileExists(t, tmpfile.Name())
+			got, err := os.ReadFile(tmpfile.Name())
+			require.NoError(t, err)
+			if test.wantEmpty {
+				assert.Empty(t, got)
+				return
 			}
-			t.Errorf("got %s, want %s", string(got), string(want))
-		}
+			want, err := os.ReadFile(test.expected)
+			require.NoError(t, err)
+			assert.Equal(t, want, got)
+		})
 	}
 }

+ 10 - 16
report/finding_test.go

@@ -1,6 +1,10 @@
 package report
 
-import "testing"
+import (
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+)
 
 func TestRedact(t *testing.T) {
 	tests := []struct {
@@ -19,12 +23,8 @@ func TestRedact(t *testing.T) {
 	for _, test := range tests {
 		for _, f := range test.findings {
 			f.Redact(100)
-			if f.Secret != "REDACTED" {
-				t.Error("redact not redacting: ", f.Secret)
-			}
-			if f.Match != "line containing REDACTED" {
-				t.Error("redact not redacting: ", f.Secret)
-			}
+			assert.Equal(t, "REDACTED", f.Secret)
+			assert.Equal(t, "line containing REDACTED", f.Match)
 		}
 	}
 }
@@ -57,12 +57,8 @@ func TestMask(t *testing.T) {
 			f := test.finding
 			e := test.expect
 			f.Redact(test.percent)
-			if f.Secret != e.Secret {
-				t.Error("redact not redacting: ", f.Secret)
-			}
-			if f.Match != e.Match {
-				t.Error("redact not redacting: ", f.Match)
-			}
+			assert.Equal(t, e.Secret, f.Secret)
+			assert.Equal(t, e.Match, f.Match)
 		})
 	}
 }
@@ -82,9 +78,7 @@ func TestMaskSecret(t *testing.T) {
 	for name, test := range tests {
 		t.Run(name, func(t *testing.T) {
 			got := maskSecret(test.secret, test.percent)
-			if got != test.expect {
-				t.Error("redact not redacting: ", got)
-			}
+			assert.Equal(t, test.expect, got)
 		})
 	}
 }

+ 18 - 31
report/json_test.go

@@ -1,11 +1,12 @@
 package report
 
 import (
-	"bytes"
 	"os"
 	"path/filepath"
-	"strings"
 	"testing"
+
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
 )
 
 func TestWriteJSON(t *testing.T) {
@@ -47,35 +48,21 @@ func TestWriteJSON(t *testing.T) {
 	}
 
 	for _, test := range tests {
-		tmpfile, err := os.Create(filepath.Join(t.TempDir(), test.testReportName+".json"))
-		if err != nil {
-			t.Error(err)
-		}
-		err = writeJson(test.findings, tmpfile)
-		if err != nil {
-			t.Error(err)
-		}
-		got, err := os.ReadFile(tmpfile.Name())
-		if err != nil {
-			t.Error(err)
-		}
-		if test.wantEmpty {
-			if len(got) > 0 {
-				t.Errorf("Expected empty file, got %s", got)
-			}
-			continue
-		}
-		want, err := os.ReadFile(test.expected)
-		if err != nil {
-			t.Error(err)
-		}
-
-		if !bytes.Equal(got, want) {
-			err = os.WriteFile(strings.Replace(test.expected, ".json", ".got.json", 1), got, 0644)
-			if err != nil {
-				t.Error(err)
+		t.Run(test.testReportName, func(t *testing.T) {
+			tmpfile, err := os.Create(filepath.Join(t.TempDir(), test.testReportName+".json"))
+			require.NoError(t, err)
+			err = writeJson(test.findings, tmpfile)
+			require.NoError(t, err)
+			assert.FileExists(t, tmpfile.Name())
+			got, err := os.ReadFile(tmpfile.Name())
+			require.NoError(t, err)
+			if test.wantEmpty {
+				assert.Empty(t, got)
+				return
 			}
-			t.Errorf("got %s, want %s", string(got), string(want))
-		}
+			want, err := os.ReadFile(test.expected)
+			require.NoError(t, err)
+			assert.Equal(t, want, got)
+		})
 	}
 }

+ 11 - 26
report/junit_test.go

@@ -3,8 +3,10 @@ package report
 import (
 	"os"
 	"path/filepath"
-	"strings"
 	"testing"
+
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
 )
 
 func TestWriteJunit(t *testing.T) {
@@ -64,36 +66,19 @@ func TestWriteJunit(t *testing.T) {
 	}
 
 	for _, test := range tests {
-		// create tmp file using os.TempDir()
 		tmpfile, err := os.Create(filepath.Join(t.TempDir(), test.testReportName+".xml"))
-		if err != nil {
-			t.Fatal(err)
-		}
+		require.NoError(t, err)
 		err = writeJunit(test.findings, tmpfile)
-		if err != nil {
-			t.Fatal(err)
-		}
+		require.NoError(t, err)
+		assert.FileExists(t, tmpfile.Name())
 		got, err := os.ReadFile(tmpfile.Name())
-		if err != nil {
-			t.Fatal(err)
-		}
+		require.NoError(t, err)
 		if test.wantEmpty {
-			if len(got) > 0 {
-				t.Errorf("Expected empty file, got %s", got)
-			}
-			continue
+			assert.Empty(t, got)
+			return
 		}
 		want, err := os.ReadFile(test.expected)
-		if err != nil {
-			t.Fatal(err)
-		}
-
-		if string(got) != string(want) {
-			err = os.WriteFile(strings.Replace(test.expected, ".xml", ".got.xml", 1), got, 0644)
-			if err != nil {
-				t.Fatal(err)
-			}
-			t.Errorf("got %s, want %s", string(got), string(want))
-		}
+		require.NoError(t, err)
+		assert.Equal(t, want, got)
 	}
 }

+ 16 - 22
report/report_test.go

@@ -6,6 +6,9 @@ import (
 	"strconv"
 	"testing"
 
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+
 	"github.com/zricethezav/gitleaks/v8/config"
 )
 
@@ -95,28 +98,19 @@ func TestReport(t *testing.T) {
 	}
 
 	for i, test := range tests {
-		tmpfile, err := os.Create(filepath.Join(t.TempDir(), strconv.Itoa(i)+test.ext))
-		if err != nil {
-			t.Error(err)
-		}
-		err = Write(test.findings, config.Config{}, test.ext, tmpfile.Name())
-		if err != nil {
-			t.Error(err)
-		}
-		got, err := os.ReadFile(tmpfile.Name())
-		if err != nil {
-			t.Error(err)
-		}
-
-		if len(got) == 0 && !test.wantEmpty {
-			t.Errorf("got empty file with extension " + test.ext)
-		}
-
-		if test.wantEmpty {
-			if len(got) > 0 {
-				t.Errorf("Expected empty file, got %s", got)
+		t.Run(test.ext, func(t *testing.T) {
+			tmpfile, err := os.Create(filepath.Join(t.TempDir(), strconv.Itoa(i)+test.ext))
+			require.NoError(t, err)
+			err = Write(test.findings, config.Config{}, test.ext, tmpfile.Name())
+			require.NoError(t, err)
+			got, err := os.ReadFile(tmpfile.Name())
+			require.NoError(t, err)
+			assert.FileExists(t, tmpfile.Name())
+			if test.wantEmpty {
+				assert.Empty(t, got)
+				return
 			}
-			continue
-		}
+			assert.NotEmpty(t, got)
+		})
 	}
 }

+ 29 - 51
report/sarif_test.go

@@ -1,14 +1,14 @@
 package report
 
 import (
-	"bytes"
-	"fmt"
 	"os"
 	"path/filepath"
-	"strings"
 	"testing"
 
 	"github.com/spf13/viper"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+
 	"github.com/zricethezav/gitleaks/v8/config"
 )
 
@@ -49,56 +49,34 @@ func TestWriteSarif(t *testing.T) {
 	}
 
 	for _, test := range tests {
-		tmpfile, err := os.Create(filepath.Join(t.TempDir(), test.testReportName+".json"))
-		if err != nil {
-			t.Error(err)
-		}
-		viper.Reset()
-		viper.AddConfigPath(configPath)
-		viper.SetConfigName(test.cfgName)
-		viper.SetConfigType("toml")
-		err = viper.ReadInConfig()
-		if err != nil {
-			t.Error(err)
-		}
-
-		var vc config.ViperConfig
-		err = viper.Unmarshal(&vc)
-		if err != nil {
-			t.Error(err)
-		}
+		t.Run(test.cfgName, func(t *testing.T) {
+			tmpfile, err := os.Create(filepath.Join(t.TempDir(), test.testReportName+".json"))
+			require.NoError(t, err)
+			viper.Reset()
+			viper.AddConfigPath(configPath)
+			viper.SetConfigName(test.cfgName)
+			viper.SetConfigType("toml")
+			err = viper.ReadInConfig()
+			require.NoError(t, err)
 
-		cfg, err := vc.Translate()
-		if err != nil {
-			t.Error(err)
-		}
-		err = writeSarif(cfg, test.findings, tmpfile)
-		fmt.Println(cfg)
-		if err != nil {
-			t.Error(err)
-		}
-		got, err := os.ReadFile(tmpfile.Name())
-		if err != nil {
-			t.Error(err)
-		}
-		if test.wantEmpty {
-			if len(got) > 0 {
-				t.Errorf("Expected empty file, got %s", got)
-			}
-			continue
-		}
-		want, err := os.ReadFile(test.expected)
-		if err != nil {
-			t.Error(err)
-		}
+			var vc config.ViperConfig
+			err = viper.Unmarshal(&vc)
+			require.NoError(t, err)
 
-		if !bytes.Equal(got, want) {
-			err = os.WriteFile(strings.Replace(test.expected, ".sarif", ".got.sarif", 1), got, 0644)
-			if err != nil {
-				t.Error(err)
+			cfg, err := vc.Translate()
+			require.NoError(t, err)
+			err = writeSarif(cfg, test.findings, tmpfile)
+			require.NoError(t, err)
+			assert.FileExists(t, tmpfile.Name())
+			got, err := os.ReadFile(tmpfile.Name())
+			require.NoError(t, err)
+			if test.wantEmpty {
+				assert.Empty(t, got)
+				return
 			}
-			t.Errorf("got %s, want %s", string(got), string(want))
-		}
-
+			want, err := os.ReadFile(test.expected)
+			require.NoError(t, err)
+			assert.Equal(t, want, got)
+		})
 	}
 }