Sfoglia il codice sorgente

refactor(detect): create readUntilSafeBoundary + add tests (#1676)

Richard Gomez 1 anno fa
parent
commit
4c3da6ebb7
5 ha cambiato i file con 190 aggiunte e 62 eliminazioni
  1. 93 54
      detect/directory.go
  2. 72 0
      detect/directory_test.go
  3. 13 7
      detect/reader.go
  4. 8 1
      detect/reader_test.go
  5. 4 0
      detect/utils.go

+ 93 - 54
detect/directory.go

@@ -1,6 +1,7 @@
 package detect
 
 import (
+	"bufio"
 	"bytes"
 	"io"
 	"os"
@@ -49,64 +50,32 @@ func (d *Detector) DetectFiles(paths <-chan sources.ScanTarget) ([]report.Findin
 				}
 			}
 
-			// Buffer to hold file chunks
-			buf := make([]byte, chunkSize)
-			totalLines := 0
+			var (
+				// Buffer to hold file chunks
+				reader     = bufio.NewReaderSize(f, chunkSize)
+				buf        = make([]byte, chunkSize)
+				totalLines = 0
+			)
 			for {
-				n, err := f.Read(buf)
-				if n > 0 {
-					// TODO: optimization could be introduced here
-					if mimetype, err := filetype.Match(buf[:n]); err != nil {
-						return nil
-					} else if mimetype.MIME.Type == "application" {
-						return nil // skip binary files
-					}
-
-					// If the chunk doesn't end in a newline, peek |maxPeekSize| until we find one.
-					// This hopefully avoids splitting
-					// See: https://github.com/gitleaks/gitleaks/issues/1651
-					var (
-						peekBuf      = bytes.NewBuffer(buf[:n])
-						tempBuf      = make([]byte, 1)
-						newlineCount = 0 // Tracks consecutive newlines
-					)
-					for {
-						data := peekBuf.Bytes()
-						if len(data) == 0 {
-							break
-						}
-
-						// Check if the last character is a newline.
-						lastChar := data[len(data)-1]
-						if lastChar == '\n' || lastChar == '\r' {
-							newlineCount++
-
-							// Stop if two consecutive newlines are found
-							if newlineCount >= 2 {
-								break
-							}
-						} else {
-							newlineCount = 0 // Reset if a non-newline character is found
-						}
-
-						// Stop growing the buffer if it reaches maxSize
-						if (peekBuf.Len() - n) >= maxPeekSize {
-							break
-						}
+				n, err := reader.Read(buf)
 
-						// Read additional data into a temporary buffer
-						m, readErr := f.Read(tempBuf)
-						if m > 0 {
-							peekBuf.Write(tempBuf[:m])
+				// "Callers should always process the n > 0 bytes returned before considering the error err."
+				// https://pkg.go.dev/io#Reader
+				if n > 0 {
+					// Only check the filetype at the start of file.
+					if totalLines == 0 {
+						// TODO: could other optimizations be introduced here?
+						if mimetype, err := filetype.Match(buf[:n]); err != nil {
+							return nil
+						} else if mimetype.MIME.Type == "application" {
+							return nil // skip binary files
 						}
+					}
 
-						// Stop if EOF is reached
-						if readErr != nil {
-							if readErr == io.EOF {
-								break
-							}
-							return readErr
-						}
+					// Try to split chunks across large areas of whitespace, if possible.
+					peekBuf := bytes.NewBuffer(buf[:n])
+					if readErr := readUntilSafeBoundary(reader, n, maxPeekSize, peekBuf); readErr != nil {
+						return readErr
 					}
 
 					// Count the number of newlines in this chunk
@@ -145,3 +114,73 @@ func (d *Detector) DetectFiles(paths <-chan sources.ScanTarget) ([]report.Findin
 
 	return d.findings, nil
 }
+
+// readUntilSafeBoundary consumes |f| until it finds two consecutive `\n` characters, up to |maxPeekSize|.
+// This hopefully avoids splitting. (https://github.com/gitleaks/gitleaks/issues/1651)
+func readUntilSafeBoundary(r *bufio.Reader, n int, maxPeekSize int, peekBuf *bytes.Buffer) error {
+	if peekBuf.Len() == 0 {
+		return nil
+	}
+
+	// Does the buffer end in consecutive newlines?
+	var (
+		data         = peekBuf.Bytes()
+		lastChar     = data[len(data)-1]
+		newlineCount = 0 // Tracks consecutive newlines
+	)
+	if isWhitespace(lastChar) {
+		for i := len(data) - 1; i >= 0; i-- {
+			lastChar = data[i]
+			if lastChar == '\n' {
+				newlineCount++
+
+				// Stop if two consecutive newlines are found
+				if newlineCount >= 2 {
+					return nil
+				}
+			} else if lastChar == '\r' || lastChar == ' ' || lastChar == '\t' {
+				// The presence of other whitespace characters (`\r`, ` `, `\t`) shouldn't reset the count.
+				// (Intentionally do nothing.)
+			} else {
+				break
+			}
+		}
+	}
+
+	// If not, read ahead until we (hopefully) find some.
+	newlineCount = 0
+	for {
+		data = peekBuf.Bytes()
+		// Check if the last character is a newline.
+		lastChar = data[len(data)-1]
+		if lastChar == '\n' {
+			newlineCount++
+
+			// Stop if two consecutive newlines are found
+			if newlineCount >= 2 {
+				break
+			}
+		} else if lastChar == '\r' || lastChar == ' ' || lastChar == '\t' {
+			// The presence of other whitespace characters (`\r`, ` `, `\t`) shouldn't reset the count.
+			// (Intentionally do nothing.)
+		} else {
+			newlineCount = 0 // Reset if a non-newline character is found
+		}
+
+		// Stop growing the buffer if it reaches maxSize
+		if (peekBuf.Len() - n) >= maxPeekSize {
+			break
+		}
+
+		// Read additional data into a temporary buffer
+		b, err := r.ReadByte()
+		if err != nil {
+			if err == io.EOF {
+				break
+			}
+			return err
+		}
+		peekBuf.WriteByte(b)
+	}
+	return nil
+}

+ 72 - 0
detect/directory_test.go

@@ -0,0 +1,72 @@
+package detect
+
+import (
+	"bufio"
+	"bytes"
+	"io"
+	"strings"
+	"testing"
+
+	"github.com/stretchr/testify/require"
+)
+
+func Test_readUntilSafeBoundary(t *testing.T) {
+	// Arrange
+	cases := []struct {
+		name     string
+		r        io.Reader
+		expected string
+	}{
+		// Current split is fine, exit early.
+		{
+			name:     "safe original split - LF",
+			r:        strings.NewReader("abc\n\ndefghijklmnop\n\nqrstuvwxyz"),
+			expected: "abc\n\n",
+		},
+		{
+			name:     "safe original split - CRLF",
+			r:        strings.NewReader("a\r\n\r\nbcdefghijklmnop\n"),
+			expected: "a\r\n\r\n",
+		},
+		// Current split is bad, look for a better one.
+		{
+			name:     "safe split - LF",
+			r:        strings.NewReader("abcdefg\nhijklmnop\n\nqrstuvwxyz"),
+			expected: "abcdefg\nhijklmnop\n\n",
+		},
+		{
+			name:     "safe split - CRLF",
+			r:        strings.NewReader("abcdefg\r\nhijklmnop\r\n\r\nqrstuvwxyz"),
+			expected: "abcdefg\r\nhijklmnop\r\n\r\n",
+		},
+		{
+			name:     "safe split - blank line",
+			r:        strings.NewReader("abcdefg\nhijklmnop\n\t  \t\nqrstuvwxyz"),
+			expected: "abcdefg\nhijklmnop\n\t  \t\n",
+		},
+		// Current split is bad, exhaust options.
+		{
+			name:     "no safe split",
+			r:        strings.NewReader("abcdefg\nhijklmnopqrstuvwxyz"),
+			expected: "abcdefg\nhijklmnopqrstuvwx",
+		},
+	}
+
+	for _, c := range cases {
+		t.Run(c.name, func(t *testing.T) {
+			buf := make([]byte, 5)
+			n, err := c.r.Read(buf)
+			require.NoError(t, err)
+
+			// Act
+			reader := bufio.NewReader(c.r)
+			peekBuf := bytes.NewBuffer(buf[:n])
+			err = readUntilSafeBoundary(reader, n, 20, peekBuf)
+			require.NoError(t, err)
+
+			// Assert
+			t.Logf(peekBuf.String())
+			require.Equal(t, c.expected, string(peekBuf.Bytes()))
+		})
+	}
+}

+ 13 - 7
detect/reader.go

@@ -2,6 +2,7 @@ package detect
 
 import (
 	"bufio"
+	"bytes"
 	"io"
 
 	"github.com/zricethezav/gitleaks/v8/report"
@@ -10,18 +11,23 @@ import (
 // DetectReader accepts an io.Reader and a buffer size for the reader in KB
 func (d *Detector) DetectReader(r io.Reader, bufSize int) ([]report.Finding, error) {
 	reader := bufio.NewReader(r)
-	buf := make([]byte, 0, 1000*bufSize)
+	buf := make([]byte, 1000*bufSize)
 	findings := []report.Finding{}
 
 	for {
-		n, err := reader.Read(buf[:cap(buf)])
+		n, err := reader.Read(buf)
 
 		// "Callers should always process the n > 0 bytes returned before considering the error err."
 		// https://pkg.go.dev/io#Reader
 		if n > 0 {
-			buf = buf[:n]
+			// Try to split chunks across large areas of whitespace, if possible.
+			peekBuf := bytes.NewBuffer(buf[:n])
+			if readErr := readUntilSafeBoundary(reader, n, maxPeekSize, peekBuf); readErr != nil {
+				return findings, readErr
+			}
+
 			fragment := Fragment{
-				Raw: string(buf),
+				Raw: peekBuf.String(),
 			}
 			for _, finding := range d.Detect(fragment) {
 				findings = append(findings, finding)
@@ -32,10 +38,10 @@ func (d *Detector) DetectReader(r io.Reader, bufSize int) ([]report.Finding, err
 		}
 
 		if err != nil {
-			if err != io.EOF {
-				return findings, err
+			if err == io.EOF {
+				break
 			}
-			break
+			return findings, err
 		}
 	}
 

+ 8 - 1
detect/reader_test.go

@@ -1,11 +1,12 @@
 package detect
 
 import (
-	"github.com/stretchr/testify/require"
 	"io"
 	"strings"
 	"testing"
 
+	"github.com/stretchr/testify/require"
+
 	"github.com/stretchr/testify/assert"
 )
 
@@ -13,11 +14,17 @@ const secret = "AKIAIRYLJVKMPEGZMPJS"
 
 type mockReader struct {
 	data []byte
+	read bool
 }
 
 func (r *mockReader) Read(p []byte) (n int, err error) {
+	if r.read {
+		return 0, io.EOF
+	}
+
 	// Copy data to the provided buffer.
 	n = copy(p, r.data)
+	r.read = true
 
 	// Return io.EOF along with the bytes.
 	return n, io.EOF

+ 4 - 0
detect/utils.go

@@ -190,3 +190,7 @@ func containsDigit(s string) bool {
 	}
 	return false
 }
+
+func isWhitespace(ch byte) bool {
+	return ch == ' ' || ch == '\t' || ch == '\n' || ch == '\r'
+}