ソースを参照

bugfix: Argument confirmations stopped working (#627) (#632)

Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
James Read 11 ヶ月 前
コミット
d4d3193c1d

+ 9 - 11
service/internal/executor/arguments.go

@@ -52,7 +52,7 @@ func parseActionArguments(values map[string]string, action *config.Action, entit
 		argName := arg.Name
 		argValue := values[argName]
 
-		err := typecheckActionArgument(argName, argValue, action)
+		err := typecheckActionArgument(&arg, argValue, action)
 
 		if err != nil {
 			return "", err
@@ -102,21 +102,19 @@ func redactShellCommand(shellCommand string, arguments []config.ActionArgument,
 	return shellCommand
 }
 
-func typecheckActionArgument(name string, value string, action *config.Action) error {
-	if name == "" {
-		return fmt.Errorf("argument name cannot be empty")
+func typecheckActionArgument(arg *config.ActionArgument, value string, action *config.Action) error {
+	if arg.Type == "confirmation" {
+		return nil
 	}
 
-	arg := action.FindArg(name)
-
-	if arg == nil {
-		return fmt.Errorf("action arg not defined: %v", name)
+	if arg.Name == "" {
+		return fmt.Errorf("argument name cannot be empty")
 	}
 
-	return typecheckActionArgumentFound(name, value, action, arg)
+	return typecheckActionArgumentFound(value, action, arg)
 }
 
-func typecheckActionArgumentFound(name string, value string, action *config.Action, arg *config.ActionArgument) error {
+func typecheckActionArgumentFound(value string, action *config.Action, arg *config.ActionArgument) error {
 	if value == "" {
 		return typecheckNull(arg)
 	}
@@ -125,7 +123,7 @@ func typecheckActionArgumentFound(name string, value string, action *config.Acti
 		return typecheckChoice(value, arg)
 	}
 
-	return TypeSafetyCheck(name, value, arg.Type)
+	return TypeSafetyCheck(arg.Name, value, arg.Type)
 }
 
 // TypeSafetyCheck checks argument values match a specific type. The types are

+ 415 - 0
service/internal/executor/arguments_test.go

@@ -1,6 +1,9 @@
 package executor
 
 import (
+	"fmt"
+	"strings"
+
 	config "github.com/OliveTin/OliveTin/internal/config"
 
 	"github.com/stretchr/testify/assert"
@@ -166,3 +169,415 @@ func TestRedactShellCommand(t *testing.T) {
 	res = redactShellCommand(cmd, args, values)
 	assert.Equal(t, cmd, res, "Missing password argument should not change the command")
 }
+
+func TestTypeSafetyCheckEmail(t *testing.T) {
+	tests := []struct {
+		name     string
+		field    string
+		value    string
+		hasError bool
+	}{
+		{"Valid simple email", "email", "user@example.com", false},
+		{"Valid email with subdomain", "email", "user@mail.example.com", false},
+		{"Valid email with plus", "email", "user+test@example.com", false},
+		{"Valid email with dash", "email", "user-name@example.com", false},
+		{"Valid email with numbers", "email", "user123@example123.com", false},
+		{"Invalid email no @", "email", "userexample.com", true},
+		{"Invalid email no domain", "email", "user@", true},
+		{"Invalid email no user", "email", "@example.com", true},
+		{"Invalid email spaces", "email", "user name@example.com", true},
+		{"Invalid email double @", "email", "user@@example.com", true},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			err := TypeSafetyCheck(tt.field, tt.value, "email")
+			if tt.hasError {
+				assert.NotNil(t, err, "Expected error for value '%s'", tt.value)
+			} else {
+				assert.Nil(t, err, "Expected no error for value '%s', but got: %v", tt.value, err)
+			}
+		})
+	}
+}
+
+func TestTypeSafetyCheckDatetime(t *testing.T) {
+	tests := []struct {
+		name     string
+		field    string
+		value    string
+		hasError bool
+	}{
+		{"Valid datetime", "datetime", "2023-12-25T15:30:45", false},
+		{"Valid datetime morning", "datetime", "2023-01-01T00:00:00", false},
+		{"Valid datetime evening", "datetime", "2023-12-31T23:59:59", false},
+		{"Invalid format missing T", "datetime", "2023-12-25 15:30:45", true},
+		{"Invalid format missing seconds", "datetime", "2023-12-25T15:30", true},
+		{"Invalid date", "datetime", "2023-13-25T15:30:45", true},
+		{"Invalid time", "datetime", "2023-12-25T25:30:45", true},
+		{"Random string", "datetime", "not-a-date", true},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			err := TypeSafetyCheck(tt.field, tt.value, "datetime")
+			if tt.hasError {
+				assert.NotNil(t, err, "Expected error for value '%s'", tt.value)
+			} else {
+				assert.Nil(t, err, "Expected no error for value '%s', but got: %v", tt.value, err)
+			}
+		})
+	}
+}
+
+func TestTypeSafetyCheckRawStringMultiline(t *testing.T) {
+	tests := []struct {
+		name  string
+		field string
+		value string
+	}{
+		{"Simple string", "content", "hello world"},
+		{"Multiline string", "content", "line1\nline2\nline3"},
+		{"String with special chars", "content", "!@#$%^&*()"},
+		{"Unicode string", "content", "héllo wörld 🌍"},
+		{"Very long string", "content", strings.Repeat("a", 1000)},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			err := TypeSafetyCheck(tt.field, tt.value, "raw_string_multiline")
+			assert.Nil(t, err, "raw_string_multiline should accept any value")
+		})
+	}
+}
+
+func TestTypeSafetyCheckUnicodeIdentifier(t *testing.T) {
+	tests := []struct {
+		name     string
+		field    string
+		value    string
+		hasError bool
+	}{
+		{"Valid unicode identifier", "name", "hello_world", false},
+		{"Valid with numbers", "name", "test123", false},
+		{"Valid with spaces", "name", "hello world", false},
+		{"Valid with path separators", "name", "path/to/file", false},
+		{"Valid with backslashes", "name", "path\\to\\file", false},
+		{"Valid with dots", "name", "file.txt", false},
+		{"Valid with underscores", "name", "my_file_name", false},
+		{"Invalid with special chars", "name", "hello@world", true},
+		{"Invalid with brackets", "name", "hello[world]", true},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			err := TypeSafetyCheck(tt.field, tt.value, "unicode_identifier")
+			if tt.hasError {
+				assert.NotNil(t, err, "Expected error for value '%s'", tt.value)
+			} else {
+				assert.Nil(t, err, "Expected no error for value '%s', but got: %v", tt.value, err)
+			}
+		})
+	}
+}
+
+func TestTypeSafetyCheckAsciiIdentifier(t *testing.T) {
+	tests := []struct {
+		name     string
+		field    string
+		value    string
+		hasError bool
+	}{
+		{"Valid identifier", "name", "hello_world", false},
+		{"Valid with numbers", "name", "test123", false},
+		{"Valid with dots", "name", "file.txt", false},
+		{"Valid with dashes", "name", "my-file", false},
+		{"Valid with underscores", "name", "my_file", false},
+		{"Invalid with spaces", "name", "hello world", true},
+		{"Invalid with special chars", "name", "hello@world", true},
+		{"Invalid unicode", "name", "héllo", true},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			err := TypeSafetyCheck(tt.field, tt.value, "ascii_identifier")
+			if tt.hasError {
+				assert.NotNil(t, err, "Expected error for value '%s'", tt.value)
+			} else {
+				assert.Nil(t, err, "Expected no error for value '%s', but got: %v", tt.value, err)
+			}
+		})
+	}
+}
+
+func TestTypeSafetyCheckAsciiSentence(t *testing.T) {
+	tests := []struct {
+		name     string
+		field    string
+		value    string
+		hasError bool
+	}{
+		{"Valid sentence", "text", "Hello world", false},
+		{"Valid with numbers", "text", "Test 123", false},
+		{"Valid with commas", "text", "Hello, world", false},
+		{"Valid with periods", "text", "Hello world.", false},
+		{"Valid with multiple spaces", "text", "Hello  world", false},
+		{"Invalid with special chars", "text", "Hello@world", true},
+		{"Invalid with parentheses", "text", "Hello (world)", true},
+		{"Invalid unicode", "text", "Héllo world", true},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			err := TypeSafetyCheck(tt.field, tt.value, "ascii_sentence")
+			if tt.hasError {
+				assert.NotNil(t, err, "Expected error for value '%s'", tt.value)
+			} else {
+				assert.Nil(t, err, "Expected no error for value '%s', but got: %v", tt.value, err)
+			}
+		})
+	}
+}
+
+func TestTypecheckActionArgumentEmptyName(t *testing.T) {
+	arg := config.ActionArgument{
+		Name: "",
+		Type: "ascii",
+	}
+	action := config.Action{Title: "Test"}
+
+	err := typecheckActionArgument(&arg, "test", &action)
+	assert.NotNil(t, err)
+	assert.Contains(t, err.Error(), "argument name cannot be empty")
+}
+
+func TestTypecheckActionArgumentConfirmation(t *testing.T) {
+	arg := config.ActionArgument{
+		Name: "confirm",
+		Type: "confirmation",
+	}
+	action := config.Action{Title: "Test"}
+
+	err := typecheckActionArgument(&arg, "any_value", &action)
+	assert.Nil(t, err, "Confirmation type should always pass validation")
+}
+
+func TestParseCommandForReplacements(t *testing.T) {
+	tests := []struct {
+		name           string
+		shellCommand   string
+		values         map[string]string
+		expectedOutput string
+		expectError    bool
+		errorContains  string
+	}{
+		{
+			name:           "Simple replacement",
+			shellCommand:   "echo {{ name }}",
+			values:         map[string]string{"name": "John"},
+			expectedOutput: "echo John",
+			expectError:    false,
+		},
+		{
+			name:           "Multiple replacements",
+			shellCommand:   "echo {{ first }} {{ last }}",
+			values:         map[string]string{"first": "John", "last": "Doe"},
+			expectedOutput: "echo John Doe",
+			expectError:    false,
+		},
+		{
+			name:           "Replacement with spaces in template",
+			shellCommand:   "echo {{  name  }}",
+			values:         map[string]string{"name": "John"},
+			expectedOutput: "echo John",
+			expectError:    false,
+		},
+		{
+			name:           "Missing argument",
+			shellCommand:   "echo {{ missing }}",
+			values:         map[string]string{},
+			expectedOutput: "",
+			expectError:    true,
+			errorContains:  "required arg not provided: missing",
+		},
+		{
+			name:           "No replacements needed",
+			shellCommand:   "echo hello",
+			values:         map[string]string{},
+			expectedOutput: "echo hello",
+			expectError:    false,
+		},
+		{
+			name:           "Multiple same argument",
+			shellCommand:   "echo {{ name }} says hello {{ name }}",
+			values:         map[string]string{"name": "Alice"},
+			expectedOutput: "echo Alice says hello Alice",
+			expectError:    false,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			output, err := parseCommandForReplacements(tt.shellCommand, tt.values)
+
+			if tt.expectError {
+				assert.NotNil(t, err, "Expected error but got none")
+				if tt.errorContains != "" {
+					assert.Contains(t, err.Error(), tt.errorContains)
+				}
+			} else {
+				assert.Nil(t, err, "Expected no error but got: %v", err)
+				assert.Equal(t, tt.expectedOutput, output)
+			}
+		})
+	}
+}
+
+func TestArgumentChoicesValidation(t *testing.T) {
+	tests := []struct {
+		name        string
+		action      config.Action
+		values      map[string]string
+		expectError bool
+		description string
+	}{
+		{
+			name: "Valid choice",
+			action: config.Action{
+				Title: "Test choices",
+				Shell: "echo {{ option }}",
+				Arguments: []config.ActionArgument{
+					{
+						Name: "option",
+						Type: "ascii",
+						Choices: []config.ActionArgumentChoice{
+							{Value: "option1", Title: "Option 1"},
+							{Value: "option2", Title: "Option 2"},
+						},
+					},
+				},
+			},
+			values:      map[string]string{"option": "option1"},
+			expectError: false,
+			description: "Should accept valid choice",
+		},
+		{
+			name: "Invalid choice",
+			action: config.Action{
+				Title: "Test choices",
+				Shell: "echo {{ option }}",
+				Arguments: []config.ActionArgument{
+					{
+						Name: "option",
+						Type: "ascii",
+						Choices: []config.ActionArgumentChoice{
+							{Value: "option1", Title: "Option 1"},
+							{Value: "option2", Title: "Option 2"},
+						},
+					},
+				},
+			},
+			values:      map[string]string{"option": "invalid_option"},
+			expectError: true,
+			description: "Should reject invalid choice",
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			_, err := parseActionArguments(tt.values, &tt.action, "")
+
+			if tt.expectError {
+				assert.NotNil(t, err, tt.description)
+				assert.Contains(t, err.Error(), "predefined choices")
+			} else {
+				assert.Nil(t, err, tt.description)
+			}
+		})
+	}
+}
+
+func TestTypeSafetyCheckVeryDangerousRawString(t *testing.T) {
+	// This type should allow anything without validation
+	tests := []string{
+		"normal text",
+		"_zomg_ c:/ haxxor ' bobby tables && rm -rf /",
+		"$(rm -rf /)",
+		"; DROP TABLE users; --",
+		"../../../../etc/passwd",
+		"",
+		"unicode: 你好世界",
+		"emojis: 🔥💀☠️",
+	}
+
+	for _, value := range tests {
+		t.Run(fmt.Sprintf("Value: %s", value), func(t *testing.T) {
+			err := TypeSafetyCheck("test", value, "very_dangerous_raw_string")
+			assert.Nil(t, err, "very_dangerous_raw_string should accept any value including: %s", value)
+		})
+	}
+}
+
+func TestParseActionArgumentsWithEntityPrefix(t *testing.T) {
+	action := config.Action{
+		Title: "Test entity prefix",
+		Shell: "echo 'Processing {{ name }} for entity'",
+		Arguments: []config.ActionArgument{
+			{Name: "name", Type: "ascii"},
+		},
+	}
+
+	values := map[string]string{
+		"name": "testuser",
+	}
+
+	// Test with entity prefix
+	output, err := parseActionArguments(values, &action, "entity_123")
+	assert.Nil(t, err)
+	assert.Contains(t, output, "testuser")
+}
+
+func TestComplexRegexPatterns(t *testing.T) {
+	tests := []struct {
+		name     string
+		pattern  string
+		value    string
+		hasError bool
+	}{
+		{
+			name:     "Phone number pattern",
+			pattern:  "regex:^\\+?[1-9]\\d{1,14}$",
+			value:    "+1234567890",
+			hasError: false,
+		},
+		{
+			name:     "Invalid phone number",
+			pattern:  "regex:^\\+?[1-9]\\d{1,14}$",
+			value:    "123abc",
+			hasError: true,
+		},
+		{
+			name:     "Semantic version pattern",
+			pattern:  "regex:^(0|[1-9]\\d*)\\.(0|[1-9]\\d*)\\.(0|[1-9]\\d*)$",
+			value:    "1.2.3",
+			hasError: false,
+		},
+		{
+			name:     "Invalid semantic version",
+			pattern:  "regex:^(0|[1-9]\\d*)\\.(0|[1-9]\\d*)\\.(0|[1-9]\\d*)$",
+			value:    "1.2",
+			hasError: true,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			err := typeSafetyCheckRegex("test", tt.value, tt.pattern)
+			if tt.hasError {
+				assert.NotNil(t, err)
+			} else {
+				assert.Nil(t, err)
+			}
+		})
+	}
+}