Parcourir la source

Add support for streaming DetectReader (#1760)

ahrav il y a 11 mois
Parent
commit
2020e6a9ab
2 fichiers modifiés avec 175 ajouts et 2 suppressions
  1. 78 0
      detect/reader.go
  2. 97 2
      detect/reader_test.go

+ 78 - 0
detect/reader.go

@@ -3,6 +3,7 @@ package detect
 import (
 	"bufio"
 	"bytes"
+	"errors"
 	"io"
 
 	"github.com/zricethezav/gitleaks/v8/report"
@@ -47,3 +48,80 @@ func (d *Detector) DetectReader(r io.Reader, bufSize int) ([]report.Finding, err
 
 	return findings, nil
 }
+
+// StreamDetectReader streams the detection results from the provided io.Reader.
+// It reads data using the specified buffer size (in KB) and processes each chunk through
+// the existing detection logic. Findings are sent down the returned findings channel as soon as
+// they are detected, while a separate error channel signals a terminal error (or nil upon successful completion).
+// The function returns two channels:
+//   - findingsCh: a receive-only channel that emits report.Finding objects as they are found.
+//   - errCh: a receive-only channel that emits a single final error (or nil if no error occurred)
+//     once the stream ends.
+//
+// Recommended Usage:
+//
+//	Since there will only ever be a single value on the errCh, it is recommended to consume the findingsCh
+//	first. Once findingsCh is closed, the consumer should then read from errCh to determine
+//	if the stream completed successfully or if an error occurred.
+//
+//	This design avoids the need for a select loop, keeping client code simple.
+//
+// Example:
+//
+//	// Assume detector is an instance of *Detector and myReader implements io.Reader.
+//	findingsCh, errCh := detector.StreamDetectReader(myReader, 64) // using 64 KB buffer size
+//
+//	// Process findings as they arrive.
+//	for finding := range findingsCh {
+//	    fmt.Printf("Found secret: %+v\n", finding)
+//	}
+//
+//	// After the findings channel is closed, check the final error.
+//	if err := <-errCh; err != nil {
+//	    log.Fatalf("StreamDetectReader encountered an error: %v", err)
+//	} else {
+//	    fmt.Println("Scanning completed successfully.")
+//	}
+func (d *Detector) StreamDetectReader(r io.Reader, bufSize int) (<-chan report.Finding, <-chan error) {
+	findingsCh := make(chan report.Finding, 1)
+	errCh := make(chan error, 1)
+
+	go func() {
+		defer close(findingsCh)
+		defer close(errCh)
+
+		reader := bufio.NewReader(r)
+		buf := make([]byte, 1000*bufSize)
+
+		for {
+			n, err := reader.Read(buf)
+
+			if n > 0 {
+				peekBuf := bytes.NewBuffer(buf[:n])
+				if readErr := readUntilSafeBoundary(reader, n, maxPeekSize, peekBuf); readErr != nil {
+					errCh <- readErr
+					return
+				}
+
+				fragment := Fragment{Raw: peekBuf.String()}
+				for _, finding := range d.Detect(fragment) {
+					findingsCh <- finding
+					if d.Verbose {
+						printFinding(finding, d.NoColor)
+					}
+				}
+			}
+
+			if err != nil {
+				if errors.Is(err, io.EOF) {
+					errCh <- nil
+					return
+				}
+				errCh <- err
+				return
+			}
+		}
+	}()
+
+	return findingsCh, errCh
+}

+ 97 - 2
detect/reader_test.go

@@ -1,13 +1,16 @@
 package detect
 
 import (
+	"bytes"
+	"errors"
 	"io"
 	"strings"
 	"testing"
-
-	"github.com/stretchr/testify/require"
+	"testing/iotest"
 
 	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+	"github.com/zricethezav/gitleaks/v8/report"
 )
 
 const secret = "AKIAIRYLJVKMPEGZMPJS"
@@ -15,6 +18,8 @@ const secret = "AKIAIRYLJVKMPEGZMPJS"
 type mockReader struct {
 	data []byte
 	read bool
+
+	errToReturn error
 }
 
 func (r *mockReader) Read(p []byte) (n int, err error) {
@@ -25,6 +30,9 @@ func (r *mockReader) Read(p []byte) (n int, err error) {
 	// Copy data to the provided buffer.
 	n = copy(p, r.data)
 	r.read = true
+	if r.errToReturn != nil {
+		return n, r.errToReturn
+	}
 
 	// Return io.EOF along with the bytes.
 	return n, io.EOF
@@ -66,3 +74,90 @@ func TestDetectReader(t *testing.T) {
 		})
 	}
 }
+
+func TestStreamDetectReader(t *testing.T) {
+	tests := []struct {
+		name          string
+		reader        io.Reader
+		bufSize       int
+		expectedCount int
+		expectError   bool
+	}{
+		{
+			name:          "Single secret streaming",
+			bufSize:       10,
+			expectedCount: 1,
+			reader:        strings.NewReader(secret),
+			expectError:   false,
+		},
+		{
+			name:          "Empty reader",
+			bufSize:       10,
+			expectedCount: 0,
+			reader:        strings.NewReader(""),
+			expectError:   false,
+		},
+		{
+			name:          "Reader returns error",
+			bufSize:       10,
+			expectedCount: 0,
+			reader:        iotest.ErrReader(errors.New("simulated read error")),
+			expectError:   true,
+		},
+		{
+			name:          "Multiple secrets with larger buffer",
+			bufSize:       20,
+			expectedCount: 2,
+			reader:        strings.NewReader(secret + "\n" + secret),
+			expectError:   false,
+		},
+		{
+			name:          "Mock reader with EOF",
+			bufSize:       10,
+			expectedCount: 1,
+			reader:        &mockReader{data: []byte(secret)},
+			expectError:   false,
+		},
+		{
+			name:          "Secret split across boundary",
+			bufSize:       1, // 1KB buffer forces multiple reads
+			expectedCount: 1,
+			reader: io.MultiReader(
+				strings.NewReader(secret[:len(secret)/2]),
+				strings.NewReader(secret[len(secret)/2:])),
+			expectError: false,
+		},
+		{
+			name:          "Reader returns error after first read",
+			bufSize:       1,
+			expectedCount: 0,
+			reader: &mockReader{
+				data:        append(bytes.Repeat([]byte("blah"), 1000), []byte(secret)...),
+				errToReturn: errors.New("simulated read error"),
+			},
+			expectError: true,
+		},
+	}
+
+	for _, test := range tests {
+		t.Run(test.name, func(t *testing.T) {
+			detector, err := NewDetectorDefaultConfig()
+			require.NoError(t, err)
+
+			findingsCh, errCh := detector.StreamDetectReader(test.reader, test.bufSize)
+			var findings []report.Finding
+			for f := range findingsCh {
+				findings = append(findings, f)
+			}
+			finalErr := <-errCh
+
+			if test.expectError {
+				require.Error(t, finalErr)
+			} else {
+				require.NoError(t, finalErr)
+			}
+
+			assert.Equal(t, test.expectedCount, len(findings))
+		})
+	}
+}