ソースを参照

one tmp dir to rule them all, I think

Zach 9 ヶ月 前
コミット
e0ea9cefde
7 ファイル変更58 行追加41 行削除
  1. 2 0
      cmd/root.go
  2. 10 10
      detect/archive.go
  3. 15 21
      detect/detect.go
  4. 4 8
      detect/directory.go
  5. 24 0
      detect/fragment.go
  6. 2 2
      detect/git.go
  7. 1 0
      sources/git.go

+ 2 - 0
cmd/root.go

@@ -421,6 +421,8 @@ func findingSummaryAndExit(detector *detect.Detector, findings []report.Finding,
 		diagnosticsManager.StopDiagnostics()
 		diagnosticsManager.StopDiagnostics()
 	}
 	}
 
 
+	detector.Cleanup()
+
 	totalBytes := detector.TotalBytes.Load()
 	totalBytes := detector.TotalBytes.Load()
 	bytesMsg := fmt.Sprintf("scanned ~%d bytes (%s)", totalBytes, bytesConvert(totalBytes))
 	bytesMsg := fmt.Sprintf("scanned ~%d bytes (%s)", totalBytes, bytesConvert(totalBytes))
 	if err == nil {
 	if err == nil {

+ 10 - 10
detect/archive.go

@@ -12,10 +12,10 @@ import (
 	"github.com/zricethezav/gitleaks/v8/sources"
 	"github.com/zricethezav/gitleaks/v8/sources"
 )
 )
 
 
-// IsArchive asks archives.Identify (with a nil stream, so only the filename)
+// isArchive asks archives.Identify (with a nil stream, so only the filename)
 // whether this file would be handled by an Extractor. If Identify returns
 // whether this file would be handled by an Extractor. If Identify returns
 // a Format implementing archives.Extractor, we treat it as an archive.
 // a Format implementing archives.Extractor, we treat it as an archive.
-func IsArchive(path string) bool {
+func isArchive(path string) bool {
 	format, _, err := archives.Identify(context.Background(), path, nil)
 	format, _, err := archives.Identify(context.Background(), path, nil)
 	if err != nil {
 	if err != nil {
 		// no matching format at all
 		// no matching format at all
@@ -27,15 +27,15 @@ func IsArchive(path string) bool {
 
 
 // ExtractArchive extracts all files from archivePath into a temp dir.
 // ExtractArchive extracts all files from archivePath into a temp dir.
 // Returns the list of ScanTargets (with real file paths) and the temp dir for cleanup.
 // Returns the list of ScanTargets (with real file paths) and the temp dir for cleanup.
-func ExtractArchive(archivePath string) ([]sources.ScanTarget, string, error) {
-	tmpDir, err := os.MkdirTemp("", "gitleaks-archive-")
+func extractArchive(archivePath string) ([]sources.ScanTarget, string, error) {
+	tmpArchiveDir, err := os.MkdirTemp(tmpDir, "archive-*")
 	if err != nil {
 	if err != nil {
-		return nil, "", err
+		return nil, "", fmt.Errorf("creating temp dir for archive: %w", err)
 	}
 	}
 
 
 	f, err := os.Open(archivePath)
 	f, err := os.Open(archivePath)
 	if err != nil {
 	if err != nil {
-		os.RemoveAll(tmpDir)
+		os.RemoveAll(tmpArchiveDir)
 		return nil, "", err
 		return nil, "", err
 	}
 	}
 	defer f.Close()
 	defer f.Close()
@@ -43,13 +43,13 @@ func ExtractArchive(archivePath string) ([]sources.ScanTarget, string, error) {
 	ctx := context.Background()
 	ctx := context.Background()
 	format, stream, err := archives.Identify(ctx, archivePath, f)
 	format, stream, err := archives.Identify(ctx, archivePath, f)
 	if err != nil {
 	if err != nil {
-		os.RemoveAll(tmpDir)
+		os.RemoveAll(tmpArchiveDir)
 		return nil, "", err
 		return nil, "", err
 	}
 	}
 
 
 	extractor, ok := format.(archives.Extractor)
 	extractor, ok := format.(archives.Extractor)
 	if !ok {
 	if !ok {
-		os.RemoveAll(tmpDir)
+		os.RemoveAll(tmpArchiveDir)
 		return nil, "", fmt.Errorf("format %T is not extractable", format)
 		return nil, "", fmt.Errorf("format %T is not extractable", format)
 	}
 	}
 
 
@@ -74,7 +74,7 @@ func ExtractArchive(archivePath string) ([]sources.ScanTarget, string, error) {
 		}
 		}
 		defer r.Close()
 		defer r.Close()
 
 
-		outPath := filepath.Join(tmpDir, file.Name())
+		outPath := filepath.Join(tmpArchiveDir, file.Name())
 		if err := os.MkdirAll(filepath.Dir(outPath), 0o755); err != nil {
 		if err := os.MkdirAll(filepath.Dir(outPath), 0o755); err != nil {
 			return err
 			return err
 		}
 		}
@@ -92,5 +92,5 @@ func ExtractArchive(archivePath string) ([]sources.ScanTarget, string, error) {
 		return nil
 		return nil
 	})
 	})
 
 
-	return targets, tmpDir, err
+	return targets, tmpArchiveDir, err
 }
 }

+ 15 - 21
detect/detect.go

@@ -36,8 +36,17 @@ const (
 var (
 var (
 	newLineRegexp = regexp.MustCompile("\n")
 	newLineRegexp = regexp.MustCompile("\n")
 	isWindows     = runtime.GOOS == "windows"
 	isWindows     = runtime.GOOS == "windows"
+	tmpDir        string
 )
 )
 
 
+func init() {
+	var err error
+	tmpDir, err = os.MkdirTemp("", "gitleaks-*")
+	if err != nil {
+		logging.Fatal().Err(err).Msg("failed to create temp dir for archive extraction")
+	}
+}
+
 // Detector is the main detector struct
 // Detector is the main detector struct
 type Detector struct {
 type Detector struct {
 	// Config is the configuration for the detector
 	// Config is the configuration for the detector
@@ -102,27 +111,12 @@ type Detector struct {
 	TotalBytes atomic.Uint64
 	TotalBytes atomic.Uint64
 }
 }
 
 
-// Fragment contains the data to be scanned
-type Fragment struct {
-	// Raw is the raw content of the fragment
-	Raw string
-
-	Bytes []byte
-
-	// FilePath is the path to the file, if applicable.
-	// The path separator MUST be normalized to `/`.
-	FilePath    string
-	SymlinkFile string
-	// WindowsFilePath is the path with the original separator.
-	// This provides a backwards-compatible solution to https://github.com/gitleaks/gitleaks/issues/1565.
-	WindowsFilePath string `json:"-"` // TODO: remove this in v9.
-
-	// CommitSHA is the SHA of the commit if applicable
-	CommitSHA string
-
-	// newlineIndices is a list of indices of newlines in the raw content.
-	// This is used to calculate the line location of a finding
-	newlineIndices [][]int
+func (d *Detector) Cleanup() {
+	if tmpDir != "" {
+		if err := os.RemoveAll(tmpDir); err != nil {
+			logging.Warn().Err(err).Msg("failed to remove temp dir for archive extraction")
+		}
+	}
 }
 }
 
 
 // NewDetector creates a new detector with the given config
 // NewDetector creates a new detector with the given config

+ 4 - 8
detect/directory.go

@@ -43,10 +43,10 @@ func (d *Detector) detectScanTarget(scanTarget sources.ScanTarget) error {
 	logger.Trace().Msg("Scanning path")
 	logger.Trace().Msg("Scanning path")
 
 
 	// --- Archive branch: extract and reschedule children ---
 	// --- Archive branch: extract and reschedule children ---
-	if IsArchive(scanTarget.Path) {
-		logger.Info().Msg("Found archive")
+	if isArchive(scanTarget.Path) {
+		logger.Debug().Msg("Found archive")
 
 
-		targets, tmpdir, err := ExtractArchive(scanTarget.Path)
+		targets, tmpArchiveDir, err := extractArchive(scanTarget.Path)
 		if err != nil {
 		if err != nil {
 			logger.Warn().Err(err).Msg("Failed to extract archive")
 			logger.Warn().Err(err).Msg("Failed to extract archive")
 			return nil
 			return nil
@@ -55,7 +55,7 @@ func (d *Detector) detectScanTarget(scanTarget sources.ScanTarget) error {
 		for _, t := range targets {
 		for _, t := range targets {
 			t := t
 			t := t
 			// compute path INSIDE this archive
 			// compute path INSIDE this archive
-			rel, rerr := filepath.Rel(tmpdir, t.Path)
+			rel, rerr := filepath.Rel(tmpArchiveDir, t.Path)
 			if rerr != nil {
 			if rerr != nil {
 				rel = filepath.Base(t.Path)
 				rel = filepath.Base(t.Path)
 			}
 			}
@@ -73,10 +73,6 @@ func (d *Detector) detectScanTarget(scanTarget sources.ScanTarget) error {
 			})
 			})
 		}
 		}
 
 
-		// cleanup extraction directory
-		// if err := os.RemoveAll(tmpdir); err != nil {
-		// 	logger.Warn().Err(err).Msg("Failed to remove tempdir")
-		// }
 		return nil
 		return nil
 	}
 	}
 
 

+ 24 - 0
detect/fragment.go

@@ -0,0 +1,24 @@
+package detect
+
+// Fragment contains the data to be scanned
+type Fragment struct {
+	// Raw is the raw content of the fragment
+	Raw string
+
+	Bytes []byte
+
+	// FilePath is the path to the file, if applicable.
+	// The path separator MUST be normalized to `/`.
+	FilePath    string
+	SymlinkFile string
+	// WindowsFilePath is the path with the original separator.
+	// This provides a backwards-compatible solution to https://github.com/gitleaks/gitleaks/issues/1565.
+	WindowsFilePath string `json:"-"` // TODO: remove this in v9.
+
+	// CommitSHA is the SHA of the commit if applicable
+	CommitSHA string
+
+	// newlineIndices is a list of indices of newlines in the raw content.
+	// This is used to calculate the line location of a finding
+	newlineIndices [][]int
+}

+ 2 - 2
detect/git.go

@@ -46,7 +46,7 @@ func (d *Detector) DetectGit(cmd *sources.GitCmd, remote *RemoteInfo) ([]report.
 				}
 				}
 			}
 			}
 
 
-			if IsArchive(gitdiffFile.NewName) {
+			if isArchive(gitdiffFile.NewName) {
 				// Check if commit is allowed
 				// Check if commit is allowed
 				d.Sema.Go(func() error {
 				d.Sema.Go(func() error {
 					// Check out the archive blob to disk
 					// Check out the archive blob to disk
@@ -57,7 +57,7 @@ func (d *Detector) DetectGit(cmd *sources.GitCmd, remote *RemoteInfo) ([]report.
 					}
 					}
 					defer os.Remove(archivePath)
 					defer os.Remove(archivePath)
 
 
-					targets, tmpDir, err := ExtractArchive(archivePath)
+					targets, tmpDir, err := extractArchive(archivePath)
 					if err != nil {
 					if err != nil {
 						os.RemoveAll(tmpDir)
 						os.RemoveAll(tmpDir)
 						logging.Warn().Err(err).Msg("failed to extract archive")
 						logging.Warn().Err(err).Msg("failed to extract archive")

+ 1 - 0
sources/git.go

@@ -141,6 +141,7 @@ func NewGitDiffCmd(source string, staged bool) (*GitCmd, error) {
 func (g *GitCmd) CheckoutBlob(commit, filepathInRepo string) (string, error) {
 func (g *GitCmd) CheckoutBlob(commit, filepathInRepo string) (string, error) {
 	// Create a temp file with the same extension as the blob, if possible
 	// Create a temp file with the same extension as the blob, if possible
 	ext := filepath.Ext(filepathInRepo)
 	ext := filepath.Ext(filepathInRepo)
+	// tmpDir, err := os.MkdirTemp("gitleaks", "archive-*")
 	tmpFile, err := os.CreateTemp("", "gitleaks-blob-*"+ext)
 	tmpFile, err := os.CreateTemp("", "gitleaks-blob-*"+ext)
 	if err != nil {
 	if err != nil {
 		return "", fmt.Errorf("creating temp file for blob: %w", err)
 		return "", fmt.Errorf("creating temp file for blob: %w", err)