Просмотр исходного кода

Resource fix (#102)

* no more dangling goroutines
Zachary Rice 7 лет назад
Родитель
Сommit
cc6d698e63
3 измененных файлов с 35 добавлено и 34 удалено
  1. 4 0
      CHANGELOG.md
  2. 1 1
      gitleaks_test.go
  3. 30 33
      main.go

+ 4 - 0
CHANGELOG.md

@@ -1,6 +1,10 @@
 CHANGELOG
 =========
 
+1.7.2
+-----
+- Fixing dangling goroutines, removing channel messaging
+
 1.7.1
 -----
 - Fixing bug where single repos were not being audited

+ 1 - 1
gitleaks_test.go

@@ -499,7 +499,7 @@ func TestAuditRepo(t *testing.T) {
 			description: "two leaks present limit goroutines",
 			numLeaks:    2,
 			testOpts: Options{
-				MaxGoRoutines: 2,
+				MaxGoRoutines: 4,
 			},
 		},
 		{

+ 30 - 33
main.go

@@ -310,7 +310,6 @@ func getRepo() (Repo, error) {
 		err  error
 		repo *git.Repository
 	)
-
 	if opts.Disk {
 		log.Infof("cloning %s", opts.Repo)
 		cloneTarget := fmt.Sprintf("%s/%x", dir, md5.Sum([]byte(fmt.Sprintf("%s%s", opts.GithubUser, opts.Repo))))
@@ -357,21 +356,25 @@ func getRepo() (Repo, error) {
 // auditRef audits a git reference
 // TODO: need to add a layer of parallelism here and a cache for parent+child commits so we don't
 // double dip
-func auditRef(repo Repo, ref *plumbing.Reference, commitWg *sync.WaitGroup, commitChan chan []Leak) error {
+func auditRef(repo Repo, ref *plumbing.Reference) []Leak {
 	var (
 		err        error
 		prevCommit *object.Commit
 		semaphore  chan bool
 		repoName   string
+		leaks      []Leak
+		commitWg   sync.WaitGroup
+		mutex      = &sync.Mutex{}
 	)
 	repoName = repo.name
 	if opts.MaxGoRoutines != 0 {
 		maxGo = opts.MaxGoRoutines
 	}
+
 	semaphore = make(chan bool, maxGo)
 	cIter, err := repo.repository.Log(&git.LogOptions{From: ref.Hash()})
 	if err != nil {
-		return err
+		return nil
 	}
 	err = cIter.ForEach(func(c *object.Commit) error {
 		if c.Hash.String() == opts.Commit {
@@ -381,18 +384,16 @@ func auditRef(repo Repo, ref *plumbing.Reference, commitWg *sync.WaitGroup, comm
 			log.Infof("skipping commit: %s\n", c.Hash.String())
 			return nil
 		}
-
-		semaphore <- true
 		commitWg.Add(1)
+		semaphore <- true
 		go func(c *object.Commit, prevCommit *object.Commit) {
 			var (
-				leaks    []Leak
 				filePath string
 				skipFile bool
 			)
 			defer func() {
+				commitWg.Done()
 				<-semaphore
-				commitChan <- leaks
 				if r := recover(); r != nil {
 					log.Warnf("recoverying from panic on commit %s, likely large diff causing panic", c.Hash.String())
 				}
@@ -406,7 +407,12 @@ func auditRef(repo Repo, ref *plumbing.Reference, commitWg *sync.WaitGroup, comm
 					if err != nil {
 						return err
 					}
-					leaks = append(leaks, checkDiff(content, c, file.Name, string(ref.Name()), repoName)...)
+					chunkLeaks := checkDiff(content, c, file.Name, string(ref.Name()), repoName)
+					for _, leak := range chunkLeaks {
+						mutex.Lock()
+						leaks = append(leaks, leak)
+						mutex.Unlock()
+					}
 					return nil
 				})
 				if err != nil {
@@ -441,7 +447,12 @@ func auditRef(repo Repo, ref *plumbing.Reference, commitWg *sync.WaitGroup, comm
 					for _, chunk := range chunks {
 						if chunk.Type() == 1 || chunk.Type() == 2 {
 							// only check if adding or removing
-							leaks = append(leaks, checkDiff(chunk.Content(), prevCommit, filePath, string(ref.Name()), repoName)...)
+							chunkLeaks := checkDiff(chunk.Content(), prevCommit, filePath, string(ref.Name()), repoName)
+							for _, leak := range chunkLeaks {
+								mutex.Lock()
+								leaks = append(leaks, leak)
+								mutex.Unlock()
+							}
 						}
 					}
 				}
@@ -450,16 +461,16 @@ func auditRef(repo Repo, ref *plumbing.Reference, commitWg *sync.WaitGroup, comm
 		prevCommit = c
 		return nil
 	})
-	return nil
+	commitWg.Wait()
+	return leaks
 }
 
 // auditRepo performs an audit on a repository checking for regex matching and ignoring
 // files and regexes that are whitelisted
 func auditRepo(repo Repo) ([]Leak, error) {
 	var (
-		err      error
-		leaks    []Leak
-		commitWg sync.WaitGroup
+		err   error
+		leaks []Leak
 	)
 
 	ref, err := repo.repository.Head()
@@ -467,9 +478,6 @@ func auditRepo(repo Repo) ([]Leak, error) {
 		return leaks, err
 	}
 
-	// leak messaging
-	commitChan := make(chan []Leak, 1)
-
 	if opts.AuditAllRefs {
 		skipBranch := false
 		refs, err := repo.repository.Storer.IterReferences()
@@ -486,7 +494,10 @@ func auditRepo(repo Repo) ([]Leak, error) {
 				skipBranch = false
 				return nil
 			}
-			auditRef(repo, ref, &commitWg, commitChan)
+			branchLeaks := auditRef(repo, ref)
+			for _, leak := range branchLeaks {
+				leaks = append(leaks, leak)
+			}
 			return nil
 		})
 	} else {
@@ -505,22 +516,8 @@ func auditRepo(repo Repo) ([]Leak, error) {
 				return nil, nil
 			}
 		}
-		auditRef(repo, ref, &commitWg, commitChan)
+		leaks = auditRef(repo, ref)
 	}
-
-	go func() {
-		for commitLeaks := range commitChan {
-			if commitLeaks != nil {
-				for _, leak := range commitLeaks {
-					leaks = append(leaks, leak)
-				}
-
-			}
-			commitWg.Done()
-		}
-	}()
-
-	commitWg.Wait()
 	return leaks, err
 }
 
@@ -871,7 +868,7 @@ func getSSHAuth() (*ssh.PublicKeys, error) {
 	return sshAuth, err
 }
 
-func (leak *Leak) log() {
+func (leak Leak) log() {
 	b, _ := json.MarshalIndent(leak, "", "   ")
 	fmt.Println(string(b))
 }