Kaynağa Gözat

fmt: reduce gocylo

jamesread 1 ay önce
ebeveyn
işleme
65a2ac54a8

+ 32 - 13
service/internal/config/sanitize.go

@@ -159,24 +159,43 @@ func (action *Action) sanitize(cfg *Config) {
 }
 
 func sanitizeActionExecutionMode(action *Action) {
-	hasShell := action.Shell != ""
-	hasExec := len(action.Exec) > 0
-	hasExecTool := action.ExecTool != nil && action.ExecTool.Name != "" && action.ExecTool.Config != nil
-
-	if hasExecTool && (hasShell || hasExec) {
-		log.Warnf("Action %q has both execTool and shell/exec; using execTool only", action.Title)
-		action.Shell = ""
-		action.Exec = nil
-	}
-	if hasExec && hasShell {
+	prioritizeExecToolOverShellAndExec(action)
+	if len(action.Exec) > 0 && action.Shell != "" {
 		log.Warnf("Action %q has both shell and exec; using exec only", action.Title)
 		action.Shell = ""
 	}
+	clearIncompleteExecTool(action)
+}
 
-	if action.ExecTool != nil && (action.ExecTool.Name == "" || action.ExecTool.Config == nil) {
-		log.Warnf("Action %q has execTool with missing name or config; clearing execTool", action.Title)
-		action.ExecTool = nil
+func prioritizeExecToolOverShellAndExec(action *Action) {
+	if !execToolIsFullyConfigured(action) {
+		return
+	}
+	if action.Shell == "" && len(action.Exec) == 0 {
+		return
+	}
+	log.Warnf("Action %q has both execTool and shell/exec; using execTool only", action.Title)
+	action.Shell = ""
+	action.Exec = nil
+}
+
+func execToolIsFullyConfigured(action *Action) bool {
+	t := action.ExecTool
+	if t == nil {
+		return false
+	}
+	if t.Name == "" {
+		return false
+	}
+	return t.Config != nil
+}
+
+func clearIncompleteExecTool(action *Action) {
+	if action.ExecTool == nil || (action.ExecTool.Name != "" && action.ExecTool.Config != nil) {
+		return
 	}
+	log.Warnf("Action %q has execTool with missing name or config; clearing execTool", action.Title)
+	action.ExecTool = nil
 }
 
 func (cfg *Config) sanitizeAuthRequireGuestsToLogin() {

+ 19 - 0
service/internal/executor/exec_tool_helper_kill.go

@@ -0,0 +1,19 @@
+package executor
+
+import (
+	"context"
+	"os/exec"
+)
+
+func runExecToolHelperKillCommand(attrs map[string]string) {
+	helper, killID := "", ""
+	if attrs != nil {
+		helper = attrs["helper"]
+		killID = attrs["kill_id"]
+	}
+	if helper == "" || killID == "" {
+		return
+	}
+	killCmd := exec.CommandContext(context.Background(), "olivetin-"+helper, "kill", killID)
+	_ = killCmd.Run()
+}

+ 95 - 59
service/internal/executor/executor.go

@@ -19,6 +19,7 @@ import (
 	"encoding/json"
 	"fmt"
 	"io"
+	"maps"
 	"os"
 	"os/exec"
 	"path"
@@ -734,29 +735,44 @@ func hasExecTool(req *ExecutionRequest) bool {
 }
 
 func handleExecToolBranch(req *ExecutionRequest) bool {
-	if err := validateArguments(req.Arguments, req.Binding.Action); err != nil {
+	if err := setupExecToolFields(req); err != nil {
 		return fail(req, err)
 	}
+	return true
+}
 
-	cfg := req.Binding.Action.ExecTool.Config
-	if cfg == nil {
-		return fail(req, fmt.Errorf("execTool config is nil"))
+func setupExecToolFields(req *ExecutionRequest) error {
+	cfg, err := validatedExecToolConfig(req)
+	if err != nil {
+		return err
 	}
-
 	applied, err := tpl.ApplyTemplatesToExecToolConfig(cfg, req.Binding.Entity, req.Arguments)
 	if err != nil {
-		return fail(req, err)
+		return err
 	}
-
 	configJSON, err := json.Marshal(applied)
 	if err != nil {
-		return fail(req, err)
+		return err
+	}
+	assignExecToolRequest(req, configJSON)
+	return nil
+}
+
+func validatedExecToolConfig(req *ExecutionRequest) (map[string]any, error) {
+	if err := validateArguments(req.Arguments, req.Binding.Action); err != nil {
+		return nil, err
+	}
+	cfg := req.Binding.Action.ExecTool.Config
+	if cfg == nil {
+		return nil, fmt.Errorf("execTool config is nil")
 	}
+	return cfg, nil
+}
 
+func assignExecToolRequest(req *ExecutionRequest, configJSON []byte) {
 	req.useExecTool = true
 	req.execToolName = req.Binding.Action.ExecTool.Name
 	req.execToolConfig = configJSON
-	return true
 }
 
 func fail(req *ExecutionRequest, err error) bool {
@@ -841,6 +857,26 @@ func appendErrorToStderr(err error, logEntry *InternalLogEntry) {
 	}
 }
 
+func finishExecutionLog(req *ExecutionRequest, ctx *timeoutContext, streamer *OutputStreamer, runerr, waiterr error, exitCode int32) {
+	req.logEntry.ExitCode = exitCode
+	req.logEntry.Output = streamer.String()
+	appendErrorToStderr(runerr, req.logEntry)
+	appendErrorToStderr(waiterr, req.logEntry)
+	applyActionTimeoutToLog(req, ctx)
+	req.logEntry.DatetimeFinished = time.Now()
+}
+
+func applyActionTimeoutToLog(req *ExecutionRequest, ctx *timeoutContext) {
+	if ctx.Err() != context.DeadlineExceeded {
+		return
+	}
+	log.WithFields(log.Fields{
+		"actionTitle": req.logEntry.ActionTitle,
+	}).Warnf("Action timed out")
+	req.logEntry.TimedOut = true
+	req.logEntry.Output += "OliveTin::timeout - this action timed out after " + fmt.Sprintf("%v", req.Binding.Action.Timeout) + " seconds. If you need more time for this action, set a longer timeout. See https://docs.olivetin.app/action_customization/timeouts.html for more help."
+}
+
 type OutputStreamer struct {
 	Req    *ExecutionRequest
 	output bytes.Buffer
@@ -873,33 +909,27 @@ func (m *MetadataStreamFilter) Write(p []byte) (n int, err error) {
 	}
 	m.buf = append(m.buf, p...)
 	if len(m.buf) > metadataMaxFirstLineLen {
-		m.done = true
-		_, _ = m.w.Write(m.buf)
-		m.buf = nil
-		return len(p), nil
+		return m.finishMetadataScanAsPlaintext(p)
 	}
 	idx := bytes.IndexByte(m.buf, '\n')
 	if idx < 0 {
 		return len(p), nil
 	}
+	return m.finishMetadataScanAtNewline(idx, p)
+}
+
+func (m *MetadataStreamFilter) finishMetadataScanAsPlaintext(p []byte) (n int, err error) {
+	m.done = true
+	_, _ = m.w.Write(m.buf)
+	m.buf = nil
+	return len(p), nil
+}
+
+func (m *MetadataStreamFilter) finishMetadataScanAtNewline(idx int, p []byte) (n int, err error) {
 	line := m.buf[:idx]
 	m.buf = m.buf[idx+1:]
 	m.done = true
-	if bytes.HasPrefix(line, []byte("OLIVETIN_METADATA ")) {
-		jsonPart := line[len("OLIVETIN_METADATA "):]
-		var attrs map[string]string
-		if json.Unmarshal(jsonPart, &attrs) == nil && attrs != nil {
-			if m.logEntry.Attributes == nil {
-				m.logEntry.Attributes = make(map[string]string)
-			}
-			for k, v := range attrs {
-				m.logEntry.Attributes[k] = v
-			}
-		}
-	} else {
-		_, _ = m.w.Write(line)
-		_, _ = m.w.Write([]byte{'\n'})
-	}
+	m.writeFirstLineAndMaybeMetadata(line)
 	if len(m.buf) > 0 {
 		_, _ = m.w.Write(m.buf)
 		m.buf = nil
@@ -907,6 +937,37 @@ func (m *MetadataStreamFilter) Write(p []byte) (n int, err error) {
 	return len(p), nil
 }
 
+func (m *MetadataStreamFilter) writeFirstLineAndMaybeMetadata(line []byte) {
+	if bytes.HasPrefix(line, []byte("OLIVETIN_METADATA ")) {
+		m.mergeMetadataLine(line[len("OLIVETIN_METADATA "):])
+		return
+	}
+	_, _ = m.w.Write(line)
+	_, _ = m.w.Write([]byte{'\n'})
+}
+
+func (m *MetadataStreamFilter) mergeMetadataLine(jsonPart []byte) {
+	attrs, ok := parseMetadataAttrsJSON(jsonPart)
+	if !ok {
+		return
+	}
+	if m.logEntry.Attributes == nil {
+		m.logEntry.Attributes = make(map[string]string)
+	}
+	maps.Copy(m.logEntry.Attributes, attrs)
+}
+
+func parseMetadataAttrsJSON(jsonPart []byte) (map[string]string, bool) {
+	var attrs map[string]string
+	if err := json.Unmarshal(jsonPart, &attrs); err != nil {
+		return nil, false
+	}
+	if attrs == nil {
+		return nil, false
+	}
+	return attrs, true
+}
+
 func buildEnv(args map[string]string) []string {
 	ret := append(os.Environ(), "OLIVETIN=1")
 
@@ -942,22 +1003,7 @@ func stepExec(req *ExecutionRequest) bool {
 	req.logEntry.Process = cmd.Process
 	ctx.setProcess(cmd.Process)
 	waiterr := cmd.Wait()
-	req.logEntry.ExitCode = int32(cmd.ProcessState.ExitCode())
-	req.logEntry.Output = streamer.String()
-
-	appendErrorToStderr(runerr, req.logEntry)
-	appendErrorToStderr(waiterr, req.logEntry)
-
-	if ctx.Err() == context.DeadlineExceeded {
-		log.WithFields(log.Fields{
-			"actionTitle": req.logEntry.ActionTitle,
-		}).Warnf("Action timed out")
-
-		req.logEntry.TimedOut = true
-		req.logEntry.Output += "OliveTin::timeout - this action timed out after " + fmt.Sprintf("%v", req.Binding.Action.Timeout) + " seconds. If you need more time for this action, set a longer timeout. See https://docs.olivetin.app/action_customization/timeouts.html for more help."
-	}
-
-	req.logEntry.DatetimeFinished = time.Now()
+	finishExecutionLog(req, ctx, streamer, runerr, waiterr, int32(cmd.ProcessState.ExitCode()))
 
 	return true
 }
@@ -973,6 +1019,10 @@ func stepExecTool(req *ExecutionRequest, ctx *timeoutContext, streamer *OutputSt
 	if cmd == nil {
 		return false
 	}
+	return runExecToolCommand(req, ctx, streamer, cmd)
+}
+
+func runExecToolCommand(req *ExecutionRequest, ctx *timeoutContext, streamer *OutputStreamer, cmd *exec.Cmd) bool {
 	stdinPayload := buildExecToolStdinPayload(req)
 	filter := &MetadataStreamFilter{w: streamer, logEntry: req.logEntry}
 	cmd.Stdout = filter
@@ -990,21 +1040,7 @@ func stepExecTool(req *ExecutionRequest, ctx *timeoutContext, streamer *OutputSt
 	_, _ = stdinPipe.Write(stdinPayload)
 	_ = stdinPipe.Close()
 	waiterr := cmd.Wait()
-	req.logEntry.ExitCode = int32(cmd.ProcessState.ExitCode())
-	req.logEntry.Output = streamer.String()
-
-	appendErrorToStderr(runerr, req.logEntry)
-	appendErrorToStderr(waiterr, req.logEntry)
-
-	if ctx.Err() == context.DeadlineExceeded {
-		log.WithFields(log.Fields{
-			"actionTitle": req.logEntry.ActionTitle,
-		}).Warnf("Action timed out")
-		req.logEntry.TimedOut = true
-		req.logEntry.Output += "OliveTin::timeout - this action timed out after " + fmt.Sprintf("%v", req.Binding.Action.Timeout) + " seconds. If you need more time for this action, set a longer timeout. See https://docs.olivetin.app/action_customization/timeouts.html for more help."
-	}
-
-	req.logEntry.DatetimeFinished = time.Now()
+	finishExecutionLog(req, ctx, streamer, runerr, waiterr, int32(cmd.ProcessState.ExitCode()))
 	return true
 }
 

+ 1 - 10
service/internal/executor/executor_unix.go

@@ -13,16 +13,7 @@ func (e *Executor) Kill(execReq *InternalLogEntry) error {
 	if execReq == nil {
 		return nil
 	}
-	helper := ""
-	killID := ""
-	if execReq.Attributes != nil {
-		helper = execReq.Attributes["helper"]
-		killID = execReq.Attributes["kill_id"]
-	}
-	if helper != "" && killID != "" {
-		killCmd := exec.CommandContext(context.Background(), "olivetin-"+helper, "kill", killID)
-		_ = killCmd.Run()
-	}
+	runExecToolHelperKillCommand(execReq.Attributes)
 	if execReq.Process != nil {
 		return syscall.Kill(-execReq.Process.Pid, syscall.SIGKILL)
 	}

+ 1 - 10
service/internal/executor/executor_windows.go

@@ -13,16 +13,7 @@ func (e *Executor) Kill(execReq *InternalLogEntry) error {
 	if execReq == nil {
 		return nil
 	}
-	helper := ""
-	killID := ""
-	if execReq.Attributes != nil {
-		helper = execReq.Attributes["helper"]
-		killID = execReq.Attributes["kill_id"]
-	}
-	if helper != "" && killID != "" {
-		killCmd := exec.CommandContext(context.Background(), "olivetin-"+helper, "kill", killID)
-		_ = killCmd.Run()
-	}
+	runExecToolHelperKillCommand(execReq.Attributes)
 	if execReq.Process != nil {
 		return execReq.Process.Kill()
 	}

+ 11 - 6
service/internal/executor/timeout_context.go

@@ -57,11 +57,16 @@ func (tc *timeoutContext) setProcess(process *os.Process) {
 	}
 	tc.processMu.Unlock()
 
-	if tc.Context.Err() == context.DeadlineExceeded && process != nil {
-		if tc.logEntry != nil {
-			_ = tc.executor.Kill(tc.logEntry)
-		} else {
-			_ = tc.executor.Kill(&InternalLogEntry{Process: process})
-		}
+	tc.killProcessIfAlreadyTimedOut(process)
+}
+
+func (tc *timeoutContext) killProcessIfAlreadyTimedOut(process *os.Process) {
+	if tc.Context.Err() != context.DeadlineExceeded || process == nil {
+		return
+	}
+	if tc.logEntry != nil {
+		_ = tc.executor.Kill(tc.logEntry)
+		return
 	}
+	_ = tc.executor.Kill(&InternalLogEntry{Process: process})
 }

+ 46 - 33
service/internal/tpl/templates.go

@@ -219,42 +219,55 @@ func applyTemplatesToValue(v any, ent *entities.Entity, args map[string]string,
 	if depth > maxExecToolConfigDepth {
 		return nil, fmt.Errorf("execTool config nested too deeply")
 	}
-	switch val := v.(type) {
-	case string:
-		return ParseTemplateWithActionContext(val, ent, args)
-	case map[string]any:
-		result := make(map[string]any, len(val))
-		for k, elem := range val {
-			transformed, err := applyTemplatesToValue(elem, ent, args, depth+1)
-			if err != nil {
-				return nil, err
-			}
-			result[k] = transformed
-		}
-		return result, nil
-	case []any:
-		result := make([]any, len(val))
-		for i, elem := range val {
-			transformed, err := applyTemplatesToValue(elem, ent, args, depth+1)
-			if err != nil {
-				return nil, err
-			}
-			result[i] = transformed
+	if s, ok := v.(string); ok {
+		return ParseTemplateWithActionContext(s, ent, args)
+	}
+	return applyTemplatesToComposite(v, ent, args, depth)
+}
+
+func applyTemplatesToComposite(v any, ent *entities.Entity, args map[string]string, depth int) (any, error) {
+	if m, ok := v.(map[string]any); ok {
+		return applyTemplatesToMap(m, ent, args, depth)
+	}
+	if items, ok := v.([]any); ok {
+		return applyTemplatesToSliceAny(items, ent, args, depth)
+	}
+	if strItems, ok := v.([]string); ok {
+		return applyTemplatesToSliceAny(stringSliceAsAnySlice(strItems), ent, args, depth)
+	}
+	return v, nil
+}
+
+func stringSliceAsAnySlice(val []string) []any {
+	out := make([]any, len(val))
+	for i, s := range val {
+		out[i] = s
+	}
+	return out
+}
+
+func applyTemplatesToMap(val map[string]any, ent *entities.Entity, args map[string]string, depth int) (map[string]any, error) {
+	result := make(map[string]any, len(val))
+	for k, elem := range val {
+		transformed, err := applyTemplatesToValue(elem, ent, args, depth+1)
+		if err != nil {
+			return nil, err
 		}
-		return result, nil
-	case []string:
-		result := make([]any, len(val))
-		for i, elem := range val {
-			transformed, err := applyTemplatesToValue(elem, ent, args, depth+1)
-			if err != nil {
-				return nil, err
-			}
-			result[i] = transformed
+		result[k] = transformed
+	}
+	return result, nil
+}
+
+func applyTemplatesToSliceAny(val []any, ent *entities.Entity, args map[string]string, depth int) ([]any, error) {
+	result := make([]any, len(val))
+	for i, elem := range val {
+		transformed, err := applyTemplatesToValue(elem, ent, args, depth+1)
+		if err != nil {
+			return nil, err
 		}
-		return result, nil
-	default:
-		return v, nil
+		result[i] = transformed
 	}
+	return result, nil
 }
 
 func ParseTemplateOfActionBeforeExec(source string, ent *entities.Entity) string {

+ 34 - 29
service/internal/tpl/templates_test.go

@@ -9,16 +9,40 @@ import (
 	"github.com/stretchr/testify/assert"
 )
 
+type parseTemplateJsonTestCase struct {
+	name           string
+	source         string
+	ent            *entities.Entity
+	args           map[string]string
+	expectedOutput string
+	expectError    bool
+	checkJsonOnly  bool
+}
+
+func assertParseTemplateJsonCase(t *testing.T, tt parseTemplateJsonTestCase, output string, err error) {
+	if tt.expectError {
+		assert.Error(t, err)
+		return
+	}
+	assert.NoError(t, err)
+	if tt.checkJsonOnly {
+		prefix := strings.TrimSuffix(tt.expectedOutput, " ")
+		assert.True(t, strings.HasPrefix(output, prefix), "output %q should start with %q", output, prefix)
+		jsonPart := strings.TrimSpace(strings.TrimPrefix(output, prefix))
+		var decoded map[string]string
+		err := json.Unmarshal([]byte(jsonPart), &decoded)
+		assert.NoError(t, err)
+		for k, v := range tt.args {
+			assert.Equal(t, v, decoded[k], "decoded JSON should contain %s=%s", k, v)
+		}
+		assert.Len(t, decoded, len(tt.args))
+		return
+	}
+	assert.Equal(t, tt.expectedOutput, output)
+}
+
 func TestParseTemplateWithActionContext_Json(t *testing.T) {
-	tests := []struct {
-		name           string
-		source         string
-		ent            *entities.Entity
-		args           map[string]string
-		expectedOutput string
-		expectError    bool
-		checkJsonOnly  bool
-	}{
+	tests := []parseTemplateJsonTestCase{
 		{
 			name:           "Arguments piped to Json",
 			source:         `echo {{ .Arguments | Json }}`,
@@ -56,26 +80,7 @@ func TestParseTemplateWithActionContext_Json(t *testing.T) {
 	for _, tt := range tests {
 		t.Run(tt.name, func(t *testing.T) {
 			output, err := ParseTemplateWithActionContext(tt.source, tt.ent, tt.args)
-			if tt.expectError {
-				assert.Error(t, err)
-				return
-			}
-			assert.NoError(t, err)
-			if tt.checkJsonOnly {
-				prefix := strings.TrimSuffix(tt.expectedOutput, " ")
-				assert.True(t, strings.HasPrefix(output, prefix), "output %q should start with %q", output, prefix)
-				jsonPart := strings.TrimPrefix(output, prefix)
-				jsonPart = strings.TrimSpace(jsonPart)
-				var decoded map[string]string
-				err := json.Unmarshal([]byte(jsonPart), &decoded)
-				assert.NoError(t, err)
-				for k, v := range tt.args {
-					assert.Equal(t, v, decoded[k], "decoded JSON should contain %s=%s", k, v)
-				}
-				assert.Len(t, decoded, len(tt.args))
-			} else {
-				assert.Equal(t, tt.expectedOutput, output)
-			}
+			assertParseTemplateJsonCase(t, tt, output, err)
 		})
 	}
 }