Jelajahi Sumber

chore: Cleanup argument handling (#571)

James Read 1 tahun lalu
induk
melakukan
eb2721c023

+ 1 - 1
.github/workflows/devskim.yml

@@ -16,7 +16,7 @@ on:
 jobs:
   lint:
     name: DevSkim
-    runs-on: ubuntu-20.04
+    runs-on: ubuntu-latest
     permissions:
       actions: read
       contents: read

+ 46 - 47
service/internal/executor/arguments.go

@@ -11,6 +11,7 @@ import (
 	"regexp"
 	"strings"
 	"time"
+	"fmt"
 )
 
 var (
@@ -24,41 +25,34 @@ var (
 	}
 )
 
-func parseCommandForReplacements(rawShellCommand string, values map[string]string) (string, map[string]string, error) {
+func parseCommandForReplacements(shellCommand string, values map[string]string) (string, error) {
 	r := regexp.MustCompile("{{ *?([a-zA-Z0-9_]+?) *?}}")
-	foundArgumentNames := r.FindAllStringSubmatch(rawShellCommand, -1)
-
-	usedArguments := make(map[string]string)
+	foundArgumentNames := r.FindAllStringSubmatch(shellCommand, -1)
 
 	for _, match := range foundArgumentNames {
 		argName := match[1]
 		argValue, argProvided := values[argName]
 
 		if !argProvided {
-			return "", nil, errors.New("Required arg not provided: " + argName)
+			return "", errors.New("Required arg not provided: " + argName)
 		}
 
-		usedArguments[argName] = argValue
-
-		rawShellCommand = strings.ReplaceAll(rawShellCommand, match[0], argValue)
+		shellCommand = strings.ReplaceAll(shellCommand, match[0], argValue)
 	}
 
-	return rawShellCommand, usedArguments, nil
+	return shellCommand, nil
 }
 
-func parseActionArguments(values map[string]string, action *config.Action, actionTitle string, entityPrefix string) (string, error) {
+func parseActionArguments(values map[string]string, action *config.Action, entityPrefix string) (string, error) {
 	log.WithFields(log.Fields{
-		"actionTitle": actionTitle,
+		"actionTitle": action.Title,
 		"cmd":         action.Shell,
 	}).Infof("Action parse args - Before")
 
-	rawShellCommand, usedArgs, err := parseCommandForReplacements(action.Shell, values)
-
-	if err != nil {
-		return "", err
-	}
+	for _, arg := range action.Arguments {
+		argName := arg.Name
+		argValue := values[argName]
 
-	for argName, argValue := range usedArgs {
 		err := typecheckActionArgument(argName, argValue, action)
 
 		if err != nil {
@@ -71,14 +65,19 @@ func parseActionArguments(values map[string]string, action *config.Action, actio
 		}).Debugf("Arg assigned")
 	}
 
-	rawShellCommand = sv.ReplaceEntityVars(entityPrefix, rawShellCommand)
+	parsedShellCommand, err := parseCommandForReplacements(action.Shell, values)
+	parsedShellCommand = sv.ReplaceEntityVars(entityPrefix, parsedShellCommand)
+
+	if err != nil {
+		return "", err
+	}
 
 	log.WithFields(log.Fields{
-		"actionTitle": actionTitle,
-		"cmd":         rawShellCommand,
+		"actionTitle": action.Title,
+		"cmd":         parsedShellCommand,
 	}).Infof("Action parse args - After")
 
-	return rawShellCommand, nil
+	return parsedShellCommand, nil
 }
 
 func typecheckActionArgument(name string, value string, action *config.Action) error {
@@ -99,6 +98,28 @@ func typecheckActionArgument(name string, value string, action *config.Action) e
 	return TypeSafetyCheck(name, value, arg.Type)
 }
 
+// TypeSafetyCheck checks argument values match a specific type. The types are
+// defined in typecheckRegex, and, you guessed it, uses regex to check for allowed
+// characters.
+//
+//gocyclo:ignore
+func TypeSafetyCheck(name string, value string, argumentType string) error {
+	switch argumentType {
+	case "password":
+		return nil
+	case "raw_string_multiline":
+		return nil
+	case "email":
+		return typeSafetyCheckEmail(value)
+	case "url":
+		return typeSafetyCheckUrl(value)
+	case "datetime":
+		return typeSafetyCheckDatetime(value)
+	}
+
+	return typeSafetyCheckRegex(name, value, argumentType)
+}
+
 func typecheckNull(arg *config.ActionArgument) error {
 	if arg.RejectNull {
 		return errors.New("Null values are not allowed")
@@ -135,29 +156,7 @@ func typecheckChoiceEntity(value string, arg *config.ActionArgument) error {
 	return errors.New("argument value cannot be found in entities")
 }
 
-// TypeSafetyCheck checks argument values match a specific type. The types are
-// defined in typecheckRegex, and, you guessed it, uses regex to check for allowed
-// characters.
-//
-//gocyclo:ignore
-func TypeSafetyCheck(name string, value string, argumentType string) error {
-	switch argumentType {
-	case "password":
-		return nil
-	case "raw_string_multiline":
-		return nil
-	case "email":
-		return typeSafetyCheckEmail(name, value)
-	case "url":
-		return typeSafetyCheckUrl(name, value)
-	case "datetime":
-		return typeSafetyCheckDatetime(name, value)
-	}
-
-	return typeSafetyCheckRegex(name, value, argumentType)
-}
-
-func typeSafetyCheckEmail(name string, value string) error {
+func typeSafetyCheckEmail(value string) error {
 	_, err := mail.ParseAddress(value)
 
 	log.Errorf("Email check: %v, %v", err, value)
@@ -169,7 +168,7 @@ func typeSafetyCheckEmail(name string, value string) error {
 	return nil
 }
 
-func typeSafetyCheckDatetime(name string, value string) error {
+func typeSafetyCheckDatetime(value string) error {
 	_, err := time.Parse("2006-01-02T15:04:05", value)
 
 	if err != nil {
@@ -203,13 +202,13 @@ func typeSafetyCheckRegex(name string, value string, argumentType string) error
 			"pattern": pattern,
 		}).Warn("Arg type check safety failure")
 
-		return errors.New("invalid argument, doesn't match " + argumentType)
+		return errors.New(fmt.Sprintf("invalid argument %v, doesn't match %v", name, argumentType))
 	}
 
 	return nil
 }
 
-func typeSafetyCheckUrl(name string, value string) error {
+func typeSafetyCheckUrl(value string) error {
 	_, err := url.ParseRequestURI(value)
 
 	return err

+ 4 - 4
service/internal/executor/arguments_test.go

@@ -33,14 +33,14 @@ func TestArgumentValueNullable(t *testing.T) {
 		"count": "",
 	}
 
-	out, err := parseActionArguments(values, &a1, a1.Title, "")
+	out, err := parseActionArguments(values, &a1, "")
 
 	assert.Equal(t, "echo 'Releasing  hounds'", out)
 	assert.Nil(t, err)
 
 	a1.Arguments[0].RejectNull = true
 
-	_, err = parseActionArguments(values, &a1, a1.Title, "")
+	_, err = parseActionArguments(values, &a1, "")
 
 	assert.NotNil(t, err)
 }
@@ -61,7 +61,7 @@ func TestArgumentNameNumbers(t *testing.T) {
 		"person1name": "Fred",
 	}
 
-	out, err := parseActionArguments(values, &a1, a1.Title, "")
+	out, err := parseActionArguments(values, &a1, "")
 
 	assert.Equal(t, "echo 'Tickling Fred'", out)
 	assert.Nil(t, err)
@@ -81,7 +81,7 @@ func TestArgumentNotProvided(t *testing.T) {
 
 	values := map[string]string{}
 
-	out, err := parseActionArguments(values, &a1, a1.Title, "")
+	out, err := parseActionArguments(values, &a1, "")
 
 	assert.Equal(t, "", out)
 	assert.Equal(t, err.Error(), "Required arg not provided: personName")

+ 4 - 4
service/internal/executor/executor.go

@@ -147,7 +147,7 @@ func (e *Executor) AddListener(m listener) {
 //	count: Number of logs to retrieve
 //
 // Returns: The calculated starting index for pagination
-func getPagingStartIndex(startOffset int64, totalLogCount int64, count int64) int64 {
+func getPagingStartIndex(startOffset int64, totalLogCount int64) int64 {
 	var startIndex int64
 
 	if startOffset <= 0 {
@@ -168,7 +168,7 @@ func (e *Executor) GetLogTrackingIds(startOffset int64, pageCount int64) ([]*Int
 
 	totalLogCount := int64(len(e.logsTrackingIdsByDate))
 
-	startIndex := getPagingStartIndex(startOffset, totalLogCount, pageCount)
+	startIndex := getPagingStartIndex(startOffset, totalLogCount)
 
 	pageCount = min(totalLogCount, pageCount)
 
@@ -402,7 +402,7 @@ func stepParseArgs(req *ExecutionRequest) bool {
 	req.Arguments["ot_executionTrackingId"] = req.TrackingID
 	req.Arguments["ot_username"] = req.AuthenticatedUser.Username
 
-	req.finalParsedCommand, err = parseActionArguments(req.Arguments, req.Action, req.logEntry.ActionTitle, req.EntityPrefix)
+	req.finalParsedCommand, err = parseActionArguments(req.Arguments, req.Action, req.EntityPrefix)
 
 	if err != nil {
 		req.logEntry.Output = err.Error()
@@ -595,7 +595,7 @@ func stepExecAfter(req *ExecutionRequest) bool {
 		"ot_username":            req.AuthenticatedUser.Username,
 	}
 
-	finalParsedCommand, _, err := parseCommandForReplacements(req.Action.ShellAfterCompleted, args)
+	finalParsedCommand, err := parseCommandForReplacements(req.Action.ShellAfterCompleted, args)
 
 	if err != nil {
 		msg := "Could not prepare shellAfterCompleted command: " + err.Error() + "\n"

+ 54 - 7
service/internal/executor/executor_test.go

@@ -82,7 +82,7 @@ func TestArgumentNameCamelCase(t *testing.T) {
 		"personName": "Fred",
 	}
 
-	out, err := parseActionArguments(values, a1, a1.Title, "")
+	out, err := parseActionArguments(values, a1, "")
 
 	assert.Equal(t, "echo 'Tickling Fred'", out)
 	assert.Nil(t, err)
@@ -104,7 +104,7 @@ func TestArgumentNameSnakeCase(t *testing.T) {
 		"person_name": "Fred",
 	}
 
-	out, err := parseActionArguments(values, a1, a1.Title, "")
+	out, err := parseActionArguments(values, a1, "")
 
 	assert.Equal(t, "echo 'Tickling Fred'", out)
 	assert.Nil(t, err)
@@ -173,9 +173,56 @@ func execNewReqAndWait(e *Executor, title string, cfg *config.Config) {
 }
 
 func TestGetPagingIndexes(t *testing.T) {
-	assert.Zero(t, getPagingStartIndex(5, 0, 5), "Testing start index from empty list")
-	assert.Equal(t, int64(4), getPagingStartIndex(5, 10, 5), "Testing start index from mid point")
-	assert.Equal(t, int64(9), getPagingStartIndex(-1, 10, 5), "Testing start index with negative offset")
-	assert.Equal(t, int64(0), getPagingStartIndex(15, 10, 5), "Testing start index with large offset")
-	assert.Equal(t, int64(9), getPagingStartIndex(0, 10, 0), "Testing start index with zero count")
+	assert.Zero(t, getPagingStartIndex(5, 0), "Testing start index from empty list")
+	assert.Equal(t, int64(4), getPagingStartIndex(5, 10), "Testing start index from mid point")
+	assert.Equal(t, int64(9), getPagingStartIndex(-1, 10), "Testing start index with negative offset")
+	assert.Equal(t, int64(0), getPagingStartIndex(15, 10), "Testing start index with large offset")
+	assert.Equal(t, int64(9), getPagingStartIndex(0, 10), "Testing start index with zero count")
+}
+
+func TestUnsetRequiredArgument(t *testing.T) {
+	a1 := &config.Action{
+		Title: "Print your name",
+		Shell: "echo 'Your name is: {{ name }}'",
+		Arguments: []config.ActionArgument{
+			{
+				Name:     "name",
+				Type:     "ascii",
+			},
+		},
+	}
+
+	values := map[string]string{}
+
+	out, err := parseActionArguments(values, a1, "")
+
+	assert.Equal(t, "", out)
+	assert.NotNil(t, err)
+}
+
+func TestUnusedArgumentStillPassesTypeSafetyCheck(t *testing.T) {
+	a1 := &config.Action{
+		Title: "Print your name",
+		Shell: "echo 'Your name is: {{ name }}'",
+		Arguments: []config.ActionArgument{
+			{
+				Name:     "name",
+				Type:     "ascii",
+			},
+			{
+				Name:     "age",
+				Type:     "int",
+			},
+		},
+	}
+
+	values := map[string]string{
+		"name": "Fred",
+		"age":  "Not an integer",
+	}
+
+	out, err := parseActionArguments(values, a1, "")
+
+	assert.Equal(t, "", out)
+	assert.NotNil(t, err)
 }