Browse Source

Fix premature exit for nogit scans, limit goroutines (#619)

* fix premature exit on nogit scan, actually limit concurreny for nogit

* removing files scanned log
Zachary Rice 4 years ago
parent
commit
6ddf27c5c3
2 changed files with 31 additions and 32 deletions
  1. 22 20
      scan/nogit.go
  2. 9 12
      scan/repo.go

+ 22 - 20
scan/nogit.go

@@ -6,6 +6,7 @@ import (
 	"os"
 	"path/filepath"
 	"strings"
+	"sync"
 
 	log "github.com/sirupsen/logrus"
 
@@ -17,15 +18,19 @@ import (
 
 // NoGitScanner is a scanner that absolutely despises git
 type NoGitScanner struct {
-	opts options.Options
-	cfg  config.Config
+	opts     options.Options
+	cfg      config.Config
+	throttle *Throttle
+	mtx      *sync.Mutex
 }
 
 // NewNoGitScanner creates and returns a nogit scanner. This is used for scanning files and directories
 func NewNoGitScanner(opts options.Options, cfg config.Config) *NoGitScanner {
 	ngs := &NoGitScanner{
-		opts: opts,
-		cfg:  cfg,
+		opts:     opts,
+		cfg:      cfg,
+		throttle: NewThrottle(opts),
+		mtx:      &sync.Mutex{},
 	}
 
 	// no-git scans should ignore .git folders by default
@@ -44,7 +49,8 @@ func (ngs *NoGitScanner) Scan() (Report, error) {
 	var scannerReport Report
 
 	g, _ := errgroup.WithContext(context.Background())
-	paths := make(chan string, 100)
+
+	paths := make(chan string)
 
 	g.Go(func() error {
 		defer close(paths)
@@ -60,11 +66,11 @@ func (ngs *NoGitScanner) Scan() (Report, error) {
 			})
 	})
 
-	leaks := make(chan Leak, 100)
-
 	for path := range paths {
 		p := path
+		ngs.throttle.Limit()
 		g.Go(func() error {
+			defer ngs.throttle.Release()
 			if ngs.cfg.Allowlist.FileAllowed(filepath.Base(p)) ||
 				ngs.cfg.Allowlist.PathAllowed(p) {
 				return nil
@@ -84,7 +90,9 @@ func (ngs *NoGitScanner) Scan() (Report, error) {
 
 					leak.Log(ngs.opts)
 
-					leaks <- leak
+					ngs.mtx.Lock()
+					scannerReport.Leaks = append(scannerReport.Leaks, leak)
+					ngs.mtx.Unlock()
 				}
 			}
 
@@ -129,26 +137,20 @@ func (ngs *NoGitScanner) Scan() (Report, error) {
 					leak.LineNumber = lineNumber
 					leak.Rule = rule.Description
 					leak.Tags = strings.Join(rule.Tags, ", ")
-
 					leak.Log(ngs.opts)
 
-					leaks <- leak
+					ngs.mtx.Lock()
+					scannerReport.Leaks = append(scannerReport.Leaks, leak)
+					ngs.mtx.Unlock()
 				}
 			}
 			return f.Close()
 		})
 	}
 
-	go func() {
-		if err := g.Wait(); err != nil {
-			log.Error(err)
-		}
-		close(leaks)
-	}()
-
-	for leak := range leaks {
-		scannerReport.Leaks = append(scannerReport.Leaks, leak)
+	if err := g.Wait(); err != nil {
+		log.Error(err)
 	}
 
-	return scannerReport, g.Wait()
+	return scannerReport, nil
 }

+ 9 - 12
scan/repo.go

@@ -2,6 +2,7 @@ package scan
 
 import (
 	"context"
+	"sync"
 
 	"golang.org/x/sync/errgroup"
 
@@ -21,6 +22,7 @@ type RepoScanner struct {
 	repo     *git.Repository
 	throttle *Throttle
 	repoName string
+	mtx      *sync.Mutex
 }
 
 // NewRepoScanner returns a new repo scanner (go figure). This function also
@@ -32,6 +34,7 @@ func NewRepoScanner(opts options.Options, cfg config.Config, repo *git.Repositor
 		repo:     repo,
 		throttle: NewThrottle(opts),
 		repoName: getRepoName(opts),
+		mtx:      &sync.Mutex{},
 	}
 
 	return rs
@@ -54,7 +57,6 @@ func (rs *RepoScanner) Scan() (Report, error) {
 
 	g, _ := errgroup.WithContext(context.Background())
 	commits = make(chan *object.Commit)
-	leaks := make(chan Leak)
 
 	commitNum := 0
 	g.Go(func() error {
@@ -91,25 +93,20 @@ func (rs *RepoScanner) Scan() (Report, error) {
 				log.Error(err)
 			}
 			for _, leak := range report.Leaks {
-				leaks <- leak
+				rs.mtx.Lock()
+				scannerReport.Leaks = append(scannerReport.Leaks, leak)
+				rs.mtx.Unlock()
 			}
 			return nil
 		})
 	}
 
-	go func() {
-		if err := g.Wait(); err != nil {
-			log.Error(err)
-		}
-		close(leaks)
-	}()
-
-	for leak := range leaks {
-		scannerReport.Leaks = append(scannerReport.Leaks, leak)
+	if err := g.Wait(); err != nil {
+		log.Error(err)
 	}
 
 	scannerReport.Commits = commitNum
-	return scannerReport, g.Wait()
+	return scannerReport, nil
 }
 
 // SetRepoName sets the repo name