ソースを参照

chore: reduce cyclo complexity in service

jamesread 8 ヶ月 前
コミット
c3d5da1981

+ 53 - 50
service/internal/acl/acl.go

@@ -196,61 +196,64 @@ func getHeaderKeyOrEmpty(headers http.Header, key string) string {
 
 // UserFromContext tries to find a user from a Connect RPC context
 func UserFromContext[T any](ctx context.Context, req *connect.Request[T], cfg *config.Config) *AuthenticatedUser {
-	var ret *AuthenticatedUser
-
-	if req != nil {
-		ret = &AuthenticatedUser{}
-		// Only trust headers if explicitly configured
-		if cfg.AuthHttpHeaderUsername != "" {
-			ret.Username = getHeaderKeyOrEmpty(req.Header(), cfg.AuthHttpHeaderUsername)
-		}
-
-		if cfg.AuthHttpHeaderUserGroup != "" {
-			ret.UsergroupLine = getHeaderKeyOrEmpty(req.Header(), cfg.AuthHttpHeaderUserGroup)
-		}
-		// Optional provider header; otherwise infer below
-		prov := getHeaderKeyOrEmpty(req.Header(), "provider")
-		if prov != "" {
-			ret.Provider = prov
-		}
-
-		// If no username from headers, fall back to local session cookie
-		if ret.Username == "" {
-			// Build a minimal http.Request to parse cookies from headers
-			dummy := &http.Request{Header: req.Header()}
-			if c, err := dummy.Cookie("olivetin-sid-local"); err == nil && c != nil && c.Value != "" {
-				if sess := auth.GetUserSession("local", c.Value); sess != nil {
-					if u := cfg.FindUserByUsername(sess.Username); u != nil {
-						ret.Username = u.Username
-						ret.UsergroupLine = u.Usergroup
-						ret.Provider = "local"
-						ret.SID = c.Value
-					} else {
-						log.WithFields(log.Fields{"username": sess.Username}).Warn("UserFromContext: local session user not in config")
-					}
-				} else {
-					log.WithFields(log.Fields{"sid": c.Value, "provider": "local"}).Warn("UserFromContext: stale local session")
-				}
-			}
-		}
-
-		if ret.Username != "" {
-			buildUserAcls(cfg, ret)
-		}
+	user := userFromHeaders(req, cfg)
+	if user.Username == "" {
+		user = userFromLocalSession(req, cfg, user)
 	}
-
-	if ret == nil || ret.Username == "" {
-		ret = UserGuest(cfg)
+	if user.Username == "" {
+		user = *UserGuest(cfg)
+	} else {
+		buildUserAcls(cfg, &user)
 	}
-
 	log.WithFields(log.Fields{
-		"username":      ret.Username,
-		"usergroupLine": ret.UsergroupLine,
-		"provider":      ret.Provider,
-		"acls":          ret.Acls,
+		"username":      user.Username,
+		"usergroupLine": user.UsergroupLine,
+		"provider":      user.Provider,
+		"acls":          user.Acls,
 	}).Debugf("UserFromContext")
+	return &user
+}
 
-	return ret
+func userFromHeaders[T any](req *connect.Request[T], cfg *config.Config) AuthenticatedUser {
+	var u AuthenticatedUser
+	if req == nil {
+		return u
+	}
+	if cfg.AuthHttpHeaderUsername != "" {
+		u.Username = getHeaderKeyOrEmpty(req.Header(), cfg.AuthHttpHeaderUsername)
+	}
+	if cfg.AuthHttpHeaderUserGroup != "" {
+		u.UsergroupLine = getHeaderKeyOrEmpty(req.Header(), cfg.AuthHttpHeaderUserGroup)
+	}
+	if prov := getHeaderKeyOrEmpty(req.Header(), "provider"); prov != "" {
+		u.Provider = prov
+	}
+	return u
+}
+
+func userFromLocalSession[T any](req *connect.Request[T], cfg *config.Config, u AuthenticatedUser) AuthenticatedUser {
+	if req == nil || u.Username != "" {
+		return u
+	}
+	dummy := &http.Request{Header: req.Header()}
+	c, err := dummy.Cookie("olivetin-sid-local")
+	if err != nil || c == nil || c.Value == "" {
+		return u
+	}
+	sess := auth.GetUserSession("local", c.Value)
+	if sess == nil {
+		log.WithFields(log.Fields{"sid": c.Value, "provider": "local"}).Warn("UserFromContext: stale local session")
+		return u
+	}
+	if cfgUser := cfg.FindUserByUsername(sess.Username); cfgUser != nil {
+		u.Username = cfgUser.Username
+		u.UsergroupLine = cfgUser.Usergroup
+		u.Provider = "local"
+		u.SID = c.Value
+		return u
+	}
+	log.WithFields(log.Fields{"username": sess.Username}).Warn("UserFromContext: local session user not in config")
+	return u
 }
 
 func UserGuest(cfg *config.Config) *AuthenticatedUser {

+ 74 - 85
service/internal/api/api.go

@@ -452,31 +452,14 @@ func (api *oliveTinAPI) GetLogs(ctx ctx.Context, req *connect.Request[apiv1.GetL
 		return nil, err
 	}
 
-	ret := &apiv1.GetLogsResponse{}
-
-	logEntries, pagingResult := api.executor.GetLogTrackingIds(req.Msg.StartOffset, api.cfg.LogHistoryPageSize)
-
-	for _, logEntry := range logEntries {
-		// Skip if binding is nil or action is nil
-		if logEntry.Binding == nil || logEntry.Binding.Action == nil {
-			continue
-		}
-
-		action := logEntry.Binding.Action
-
-		if acl.IsAllowedLogs(api.cfg, user, action) {
-			pbLogEntry := api.internalLogEntryToPb(logEntry, user)
-
-			ret.Logs = append(ret.Logs, pbLogEntry)
-		}
-	}
-
-	ret.CountRemaining = pagingResult.CountRemaining
-	ret.PageSize = pagingResult.PageSize
-	ret.TotalCount = pagingResult.TotalCount
-	ret.StartOffset = pagingResult.StartOffset
-
-	return connect.NewResponse(ret), nil
+    ret := &apiv1.GetLogsResponse{}
+    logEntries, paging := api.executor.GetLogTrackingIds(req.Msg.StartOffset, api.cfg.LogHistoryPageSize)
+    ret.Logs = api.pbLogsFiltered(logEntries, user)
+    ret.CountRemaining = paging.CountRemaining
+    ret.PageSize = paging.PageSize
+    ret.TotalCount = paging.TotalCount
+    ret.StartOffset = paging.StartOffset
+    return connect.NewResponse(ret), nil
 }
 
 func (api *oliveTinAPI) GetActionLogs(ctx ctx.Context, req *connect.Request[apiv1.GetActionLogsRequest]) (*connect.Response[apiv1.GetActionLogsResponse], error) {
@@ -486,66 +469,72 @@ func (api *oliveTinAPI) GetActionLogs(ctx ctx.Context, req *connect.Request[apiv
 		return nil, err
 	}
 
-	ret := &apiv1.GetActionLogsResponse{}
-
-	logs := api.executor.GetLogsByActionId(req.Msg.ActionId)
-
-	// Apply ACL filtering
-	filteredLogs := make([]*executor.InternalLogEntry, 0)
-	for _, logEntry := range logs {
-		// Skip if binding is nil or action is nil
-		if logEntry.Binding == nil || logEntry.Binding.Action == nil {
-			continue
-		}
-
-		action := logEntry.Binding.Action
-		if acl.IsAllowedLogs(api.cfg, user, action) {
-			filteredLogs = append(filteredLogs, logEntry)
-		}
-	}
-
-	// Pagination
-	totalCount := int64(len(filteredLogs))
-	pageSize := api.cfg.LogHistoryPageSize
-	startOffset := req.Msg.StartOffset
-
-	// Validate and clamp offset to prevent out-of-bounds access
-	if startOffset < 0 {
-		startOffset = 0
-	}
-
-	// If offset is beyond available data, return empty result with correct metadata
-	if startOffset >= totalCount {
-		ret.CountRemaining = 0
-		ret.PageSize = pageSize
-		ret.TotalCount = totalCount
-		ret.StartOffset = startOffset
-		return connect.NewResponse(ret), nil
-	}
-
-	startIdx := startOffset
-	endIdx := startOffset + pageSize
-	if endIdx > totalCount {
-		endIdx = totalCount
-	}
-
-	logEntries := filteredLogs[startIdx:endIdx]
-	countRemaining := totalCount - endIdx
-	if countRemaining < 0 {
-		countRemaining = 0
-	}
-
-	for _, logEntry := range logEntries {
-		pbLogEntry := api.internalLogEntryToPb(logEntry, user)
-		ret.Logs = append(ret.Logs, pbLogEntry)
-	}
-
-	ret.CountRemaining = countRemaining
-	ret.PageSize = pageSize
-	ret.TotalCount = totalCount
-	ret.StartOffset = startOffset
-
-	return connect.NewResponse(ret), nil
+    ret := &apiv1.GetActionLogsResponse{}
+    filtered := api.filterLogsByACL(api.executor.GetLogsByActionId(req.Msg.ActionId), user)
+    page := paginate(int64(len(filtered)), api.cfg.LogHistoryPageSize, req.Msg.StartOffset)
+    if page.empty {
+        ret.CountRemaining = 0
+        ret.PageSize = page.size
+        ret.TotalCount = page.total
+        ret.StartOffset = page.start
+        return connect.NewResponse(ret), nil
+    }
+    for _, le := range filtered[page.start:page.end] {
+        ret.Logs = append(ret.Logs, api.internalLogEntryToPb(le, user))
+    }
+    ret.CountRemaining = page.total - page.end
+    ret.PageSize = page.size
+    ret.TotalCount = page.total
+    ret.StartOffset = page.start
+    return connect.NewResponse(ret), nil
+}
+
+func (api *oliveTinAPI) pbLogsFiltered(entries []*executor.InternalLogEntry, user *acl.AuthenticatedUser) []*apiv1.LogEntry {
+    out := make([]*apiv1.LogEntry, 0, len(entries))
+    for _, e := range entries {
+        if e == nil || e.Binding == nil || e.Binding.Action == nil {
+            continue
+        }
+        if acl.IsAllowedLogs(api.cfg, user, e.Binding.Action) {
+            out = append(out, api.internalLogEntryToPb(e, user))
+        }
+    }
+    return out
+}
+
+func (api *oliveTinAPI) filterLogsByACL(entries []*executor.InternalLogEntry, user *acl.AuthenticatedUser) []*executor.InternalLogEntry {
+    filtered := make([]*executor.InternalLogEntry, 0, len(entries))
+    for _, e := range entries {
+        if e == nil || e.Binding == nil || e.Binding.Action == nil {
+            continue
+        }
+        if acl.IsAllowedLogs(api.cfg, user, e.Binding.Action) {
+            filtered = append(filtered, e)
+        }
+    }
+    return filtered
+}
+
+type pageInfo struct {
+    total int64
+    size  int64
+    start int64
+    end   int64
+    empty bool
+}
+
+func paginate(total int64, size int64, start int64) pageInfo {
+    if start < 0 {
+        start = 0
+    }
+    if start >= total {
+        return pageInfo{total: total, size: size, start: start, end: start, empty: true}
+    }
+    end := start + size
+    if end > total {
+        end = total
+    }
+    return pageInfo{total: total, size: size, start: start, end: end, empty: false}
 }
 
 /*

+ 33 - 35
service/internal/auth/sessions.go

@@ -82,43 +82,41 @@ func GetUserSession(provider string, sid string) *UserSession {
 
 // LoadUserSessions loads sessions from disk
 func LoadUserSessions(cfg *config.Config) {
-	sessionStorageMutex.Lock()
-	defer sessionStorageMutex.Unlock()
-
-	data, err := os.ReadFile(cfg.GetDir() + "/sessions.yaml")
-	if err != nil {
-		logrus.WithError(err).Warn("Failed to read sessions.yaml file")
-		// Initialize empty session storage if file doesn't exist
-		if sessionStorage == nil {
-			sessionStorage = &SessionStorage{
-				Providers: make(map[string]*SessionProvider),
-			}
-		}
-		return
-	}
-
-	err = yaml.Unmarshal(data, &sessionStorage)
-	if err != nil {
-		logrus.WithError(err).Error("Failed to unmarshal sessions.yaml")
-		// Initialize empty session storage if unmarshal fails
-		if sessionStorage == nil {
-			sessionStorage = &SessionStorage{
-				Providers: make(map[string]*SessionProvider),
-			}
-		}
-		return
-	}
+    sessionStorageMutex.Lock()
+    defer sessionStorageMutex.Unlock()
+
+    data, err := os.ReadFile(cfg.GetDir() + "/sessions.yaml")
+    if err != nil {
+        logrus.WithError(err).Warn("Failed to read sessions.yaml file")
+        ensureEmptySessionStorage()
+        return
+    }
+
+    if err := yaml.Unmarshal(data, &sessionStorage); err != nil {
+        logrus.WithError(err).Error("Failed to unmarshal sessions.yaml")
+        ensureEmptySessionStorage()
+        return
+    }
+
+    ensureSessionStorageInitialized()
+}
 
-	// Ensure sessionStorage and Providers are properly initialized
-	if sessionStorage == nil {
-		sessionStorage = &SessionStorage{
-			Providers: make(map[string]*SessionProvider),
-		}
-	}
+func ensureEmptySessionStorage() {
+    if sessionStorage == nil {
+        sessionStorage = &SessionStorage{Providers: make(map[string]*SessionProvider)}
+    }
+    if sessionStorage.Providers == nil {
+        sessionStorage.Providers = make(map[string]*SessionProvider)
+    }
+}
 
-	if sessionStorage.Providers == nil {
-		sessionStorage.Providers = make(map[string]*SessionProvider)
-	}
+func ensureSessionStorageInitialized() {
+    if sessionStorage == nil {
+        sessionStorage = &SessionStorage{Providers: make(map[string]*SessionProvider)}
+    }
+    if sessionStorage.Providers == nil {
+        sessionStorage.Providers = make(map[string]*SessionProvider)
+    }
 }
 
 func saveUserSessions(cfg *config.Config) {

+ 224 - 209
service/internal/config/config_reloader.go

@@ -46,62 +46,88 @@ func AppendSourceWithIncludes(cfg *Config, k *koanf.Koanf, configPath string) {
 }
 
 func AppendSource(cfg *Config, k *koanf.Koanf, configPath string) {
-	log.Infof("Appending cfg source: %s", configPath)
+    log.Infof("Appending cfg source: %s", configPath)
 
-	// Unmarshal config - koanf will handle mapstructure tags automatically
-	err := k.Unmarshal(".", cfg)
-	if err != nil {
-		log.Errorf("Error unmarshalling config: %v", err)
-		return
-	}
+    if !unmarshalRoot(k, cfg) {
+        return
+    }
 
-	// Fallback for complex nested structures that might not unmarshal correctly
-	// Only attempt manual unmarshaling if the automatic approach didn't populate the fields
-	if len(cfg.Actions) == 0 && k.Exists("actions") {
-		var actions []*Action
-		if err := k.Unmarshal("actions", &actions); err == nil {
-			cfg.Actions = actions
-			log.Debugf("Manually loaded %d actions", len(actions))
-		}
-	}
+    loadCollectionsFallbacks(k, cfg)
 
-	if len(cfg.Dashboards) == 0 && k.Exists("dashboards") {
-		var dashboards []*DashboardComponent
-		if err := k.Unmarshal("dashboards", &dashboards); err == nil {
-			cfg.Dashboards = dashboards
-			log.Debugf("Manually loaded %d dashboards", len(dashboards))
-		}
-	}
+    applyConfigOverrides(k, cfg)
 
-	if len(cfg.Entities) == 0 && k.Exists("entities") {
-		var entities []*EntityFile
-		if err := k.Unmarshal("entities", &entities); err == nil {
-			cfg.Entities = entities
-			log.Debugf("Manually loaded %d entities", len(entities))
-		}
-	}
+    afterLoadFinalize(cfg, configPath)
+}
 
-	if len(cfg.AuthLocalUsers.Users) == 0 && k.Exists("authLocalUsers") {
-		var authLocalUsers AuthLocalUsersConfig
-		if err := k.Unmarshal("authLocalUsers", &authLocalUsers); err == nil {
-			cfg.AuthLocalUsers = authLocalUsers
-			log.Debugf("Manually loaded local auth config")
-		}
-	}
+func unmarshalRoot(k *koanf.Koanf, cfg *Config) bool {
+    if err := k.Unmarshal(".", cfg); err != nil {
+        log.Errorf("Error unmarshalling config: %v", err)
+        return false
+    }
+    return true
+}
 
-	// Map structure tags should handle these automatically, but we keep fallbacks
-	// for fields that might not unmarshal correctly
-	applyConfigOverrides(k, cfg)
+func loadCollectionsFallbacks(k *koanf.Koanf, cfg *Config) {
+    maybeUnmarshalActions(k, cfg)
+    maybeUnmarshalDashboards(k, cfg)
+    maybeUnmarshalEntities(k, cfg)
+    maybeUnmarshalAuthLocalUsers(k, cfg)
+}
 
-	metricConfigReloadedCount.Inc()
-	metricConfigActionCount.Set(float64(len(cfg.Actions)))
+func maybeUnmarshalActions(k *koanf.Koanf, cfg *Config) {
+    if len(cfg.Actions) != 0 || !k.Exists("actions") {
+        return
+    }
+    var actions []*Action
+    if err := k.Unmarshal("actions", &actions); err == nil {
+        cfg.Actions = actions
+        log.Debugf("Manually loaded %d actions", len(actions))
+    }
+}
 
-	cfg.SetDir(filepath.Dir(configPath))
-	cfg.Sanitize()
+func maybeUnmarshalDashboards(k *koanf.Koanf, cfg *Config) {
+    if len(cfg.Dashboards) != 0 || !k.Exists("dashboards") {
+        return
+    }
+    var dashboards []*DashboardComponent
+    if err := k.Unmarshal("dashboards", &dashboards); err == nil {
+        cfg.Dashboards = dashboards
+        log.Debugf("Manually loaded %d dashboards", len(dashboards))
+    }
+}
 
-	for _, l := range listeners {
-		l()
-	}
+func maybeUnmarshalEntities(k *koanf.Koanf, cfg *Config) {
+    if len(cfg.Entities) != 0 || !k.Exists("entities") {
+        return
+    }
+    var entities []*EntityFile
+    if err := k.Unmarshal("entities", &entities); err == nil {
+        cfg.Entities = entities
+        log.Debugf("Manually loaded %d entities", len(entities))
+    }
+}
+
+func maybeUnmarshalAuthLocalUsers(k *koanf.Koanf, cfg *Config) {
+    if len(cfg.AuthLocalUsers.Users) != 0 || !k.Exists("authLocalUsers") {
+        return
+    }
+    var authLocalUsers AuthLocalUsersConfig
+    if err := k.Unmarshal("authLocalUsers", &authLocalUsers); err == nil {
+        cfg.AuthLocalUsers = authLocalUsers
+        log.Debugf("Manually loaded local auth config")
+    }
+}
+
+func afterLoadFinalize(cfg *Config, configPath string) {
+    metricConfigReloadedCount.Inc()
+    metricConfigActionCount.Set(float64(len(cfg.Actions)))
+
+    cfg.SetDir(filepath.Dir(configPath))
+    cfg.Sanitize()
+
+    for _, l := range listeners {
+        l()
+    }
 }
 
 func applyConfigOverrides(k *koanf.Koanf, cfg *Config) {
@@ -131,181 +157,170 @@ func applyConfigOverrides(k *koanf.Koanf, cfg *Config) {
 
 // LoadIncludedConfigs loads configuration files from an include directory and merges them
 func LoadIncludedConfigs(cfg *Config, k *koanf.Koanf, baseConfigPath string) {
-	if cfg.Include == "" {
-		return
-	}
-
-	configDir := filepath.Dir(baseConfigPath)
-	includePath := filepath.Join(configDir, cfg.Include)
-
-	log.Infof("Loading included configs from: %s", includePath)
+    if cfg.Include == "" {
+        return
+    }
 
-	// Check if the include directory exists
-	dirInfo, err := os.Stat(includePath)
-	if err != nil {
-		log.Warnf("Include directory not found: %s", includePath)
-		return
-	}
+    includePath := filepath.Join(filepath.Dir(baseConfigPath), cfg.Include)
+    log.Infof("Loading included configs from: %s", includePath)
 
-	if !dirInfo.IsDir() {
-		log.Warnf("Include path is not a directory: %s", includePath)
-		return
-	}
+    yamlFiles, ok := listYamlFiles(includePath)
+    if !ok || len(yamlFiles) == 0 {
+        return
+    }
 
-	// Read all .yml files from the directory
-	entries, err := os.ReadDir(includePath)
-	if err != nil {
-		log.Errorf("Error reading include directory: %v", err)
-		return
-	}
+    sort.Strings(yamlFiles)
+    for _, filename := range yamlFiles {
+        loadAndMergeIncludedFile(cfg, includePath, filename)
+    }
 
-	// Filter and sort .yml files
-	var yamlFiles []string
-	for _, entry := range entries {
-		if !entry.IsDir() && (strings.HasSuffix(entry.Name(), ".yml") || strings.HasSuffix(entry.Name(), ".yaml")) {
-			yamlFiles = append(yamlFiles, entry.Name())
-		}
-	}
-
-	if len(yamlFiles) == 0 {
-		log.Infof("No YAML files found in include directory: %s", includePath)
-		return
-	}
-
-	// Sort files to ensure deterministic load order
-	sort.Strings(yamlFiles)
-
-	// Load each file and merge into config
-	for _, filename := range yamlFiles {
-		filePath := filepath.Join(includePath, filename)
-		log.Infof("Loading included config file: %s", filePath)
-
-		includeK := koanf.New(".")
-		f := file.Provider(filePath)
-
-		if err := includeK.Load(f, yaml.Parser()); err != nil {
-			log.Errorf("Error loading included config file %s: %v", filePath, err)
-			continue
-		}
-
-		// Unmarshal into a temporary config to process properly
-		tempCfg := &Config{}
-		if err := includeK.Unmarshal(".", tempCfg); err != nil {
-			log.Errorf("Error unmarshalling included config file %s: %v", filePath, err)
-			continue
-		}
-
-		// Apply the same manual loading workarounds as in AppendSource
-		if len(tempCfg.Actions) == 0 && includeK.Exists("actions") {
-			var actions []*Action
-			if err := includeK.Unmarshal("actions", &actions); err == nil {
-				tempCfg.Actions = actions
-				log.Debugf("Manually loaded %d actions from %s", len(actions), filename)
-			}
-		}
-
-		// Merge the temp config into the main config
-		// Later files override earlier ones
-		mergeConfig(cfg, tempCfg)
-
-		log.Infof("Successfully loaded and merged %s", filename)
-	}
+    log.Infof("Finished loading %d included config file(s)", len(yamlFiles))
+    cfg.Sanitize()
+}
 
-	log.Infof("Finished loading %d included config file(s)", len(yamlFiles))
+func listYamlFiles(includePath string) ([]string, bool) {
+    dirInfo, err := os.Stat(includePath)
+    if err != nil {
+        log.Warnf("Include directory not found: %s", includePath)
+        return nil, false
+    }
+    if !dirInfo.IsDir() {
+        log.Warnf("Include path is not a directory: %s", includePath)
+        return nil, false
+    }
+    entries, err := os.ReadDir(includePath)
+    if err != nil {
+        log.Errorf("Error reading include directory: %v", err)
+        return nil, false
+    }
+    var yamlFiles []string
+    for _, entry := range entries {
+        if entry.IsDir() {
+            continue
+        }
+        name := entry.Name()
+        if strings.HasSuffix(name, ".yml") || strings.HasSuffix(name, ".yaml") {
+            yamlFiles = append(yamlFiles, name)
+        }
+    }
+    if len(yamlFiles) == 0 {
+        log.Infof("No YAML files found in include directory: %s", includePath)
+    }
+    return yamlFiles, true
+}
 
-	// Sanitize the merged config
-	cfg.Sanitize()
+func loadAndMergeIncludedFile(cfg *Config, includePath, filename string) {
+    filePath := filepath.Join(includePath, filename)
+    log.Infof("Loading included config file: %s", filePath)
+
+    includeK := koanf.New(".")
+    if err := includeK.Load(file.Provider(filePath), yaml.Parser()); err != nil {
+        log.Errorf("Error loading included config file %s: %v", filePath, err)
+        return
+    }
+
+    tempCfg := &Config{}
+    if err := includeK.Unmarshal(".", tempCfg); err != nil {
+        log.Errorf("Error unmarshalling included config file %s: %v", filePath, err)
+        return
+    }
+    // Fallbacks similar to AppendSource
+    if len(tempCfg.Actions) == 0 && includeK.Exists("actions") {
+        var actions []*Action
+        if err := includeK.Unmarshal("actions", &actions); err == nil {
+            tempCfg.Actions = actions
+            log.Debugf("Manually loaded %d actions from %s", len(actions), filename)
+        }
+    }
+
+    mergeConfig(cfg, tempCfg)
+    log.Infof("Successfully loaded and merged %s", filename)
 }
 
 func mergeConfig(base *Config, overlay *Config) {
-	// Merge Actions - overlay appends to base
-	if len(overlay.Actions) > 0 {
-		base.Actions = append(base.Actions, overlay.Actions...)
-	}
-
-	// Merge Dashboards - overlay appends to base
-	if len(overlay.Dashboards) > 0 {
-		base.Dashboards = append(base.Dashboards, overlay.Dashboards...)
-		log.Debugf("Merged %d dashboards from include", len(overlay.Dashboards))
-	}
-
-	// Merge Entities - overlay appends to base
-	if len(overlay.Entities) > 0 {
-		base.Entities = append(base.Entities, overlay.Entities...)
-		log.Debugf("Merged %d entities from include", len(overlay.Entities))
-	}
-
-	// Merge AccessControlLists - overlay appends to base
-	if len(overlay.AccessControlLists) > 0 {
-		base.AccessControlLists = append(base.AccessControlLists, overlay.AccessControlLists...)
-		log.Debugf("Merged %d access control lists from include", len(overlay.AccessControlLists))
-	}
-
-	// Merge AuthLocalUsers.Users - overlay appends to base
-	if len(overlay.AuthLocalUsers.Users) > 0 {
-		base.AuthLocalUsers.Users = append(base.AuthLocalUsers.Users, overlay.AuthLocalUsers.Users...)
-		log.Debugf("Merged %d local users from include", len(overlay.AuthLocalUsers.Users))
-	}
-
-	// Merge slices by appending
-	if len(overlay.StyleMods) > 0 {
-		base.StyleMods = append(base.StyleMods, overlay.StyleMods...)
-	}
-
-	if len(overlay.AdditionalNavigationLinks) > 0 {
-		base.AdditionalNavigationLinks = append(base.AdditionalNavigationLinks, overlay.AdditionalNavigationLinks...)
-	}
-
-	// Override simple fields (later files win)
-	if overlay.LogLevel != "" {
-		base.LogLevel = overlay.LogLevel
-	}
-	if overlay.PageTitle != "" {
-		base.PageTitle = overlay.PageTitle
-	}
-	if overlay.ShowFooter != base.ShowFooter {
-		base.ShowFooter = overlay.ShowFooter
-	}
-	if overlay.ShowNavigation != base.ShowNavigation {
-		base.ShowNavigation = overlay.ShowNavigation
-	}
-	if overlay.CheckForUpdates != base.CheckForUpdates {
-		base.CheckForUpdates = overlay.CheckForUpdates
-	}
-	if overlay.UseSingleHTTPFrontend != base.UseSingleHTTPFrontend {
-		base.UseSingleHTTPFrontend = overlay.UseSingleHTTPFrontend
-	}
-	if overlay.AuthRequireGuestsToLogin != base.AuthRequireGuestsToLogin {
-		base.AuthRequireGuestsToLogin = overlay.AuthRequireGuestsToLogin
-	}
+    mergeSlices(base, overlay)
+    overrideSimple(base, overlay)
+    overrideNested(base, overlay)
+    overrideStrings(base, overlay)
+}
 
-	// Override nested structs
-	if overlay.DefaultPolicy.ShowDiagnostics != base.DefaultPolicy.ShowDiagnostics {
-		base.DefaultPolicy.ShowDiagnostics = overlay.DefaultPolicy.ShowDiagnostics
-	}
-	if overlay.DefaultPolicy.ShowLogList != base.DefaultPolicy.ShowLogList {
-		base.DefaultPolicy.ShowLogList = overlay.DefaultPolicy.ShowLogList
-	}
+func mergeSlices(base *Config, overlay *Config) {
+    if len(overlay.Actions) > 0 {
+        base.Actions = append(base.Actions, overlay.Actions...)
+    }
+    if len(overlay.Dashboards) > 0 {
+        base.Dashboards = append(base.Dashboards, overlay.Dashboards...)
+        log.Debugf("Merged %d dashboards from include", len(overlay.Dashboards))
+    }
+    if len(overlay.Entities) > 0 {
+        base.Entities = append(base.Entities, overlay.Entities...)
+        log.Debugf("Merged %d entities from include", len(overlay.Entities))
+    }
+    if len(overlay.AccessControlLists) > 0 {
+        base.AccessControlLists = append(base.AccessControlLists, overlay.AccessControlLists...)
+        log.Debugf("Merged %d access control lists from include", len(overlay.AccessControlLists))
+    }
+    if len(overlay.AuthLocalUsers.Users) > 0 {
+        base.AuthLocalUsers.Users = append(base.AuthLocalUsers.Users, overlay.AuthLocalUsers.Users...)
+        log.Debugf("Merged %d local users from include", len(overlay.AuthLocalUsers.Users))
+    }
+    if len(overlay.StyleMods) > 0 {
+        base.StyleMods = append(base.StyleMods, overlay.StyleMods...)
+    }
+    if len(overlay.AdditionalNavigationLinks) > 0 {
+        base.AdditionalNavigationLinks = append(base.AdditionalNavigationLinks, overlay.AdditionalNavigationLinks...)
+    }
+}
 
-	if overlay.Prometheus.Enabled != base.Prometheus.Enabled {
-		base.Prometheus.Enabled = overlay.Prometheus.Enabled
-	}
-	if overlay.Prometheus.DefaultGoMetrics != base.Prometheus.DefaultGoMetrics {
-		base.Prometheus.DefaultGoMetrics = overlay.Prometheus.DefaultGoMetrics
-	}
+func overrideSimple(base *Config, overlay *Config) {
+    if overlay.LogLevel != "" {
+        base.LogLevel = overlay.LogLevel
+    }
+    if overlay.PageTitle != "" {
+        base.PageTitle = overlay.PageTitle
+    }
+    if overlay.ShowFooter != base.ShowFooter {
+        base.ShowFooter = overlay.ShowFooter
+    }
+    if overlay.ShowNavigation != base.ShowNavigation {
+        base.ShowNavigation = overlay.ShowNavigation
+    }
+    if overlay.CheckForUpdates != base.CheckForUpdates {
+        base.CheckForUpdates = overlay.CheckForUpdates
+    }
+    if overlay.UseSingleHTTPFrontend != base.UseSingleHTTPFrontend {
+        base.UseSingleHTTPFrontend = overlay.UseSingleHTTPFrontend
+    }
+    if overlay.AuthRequireGuestsToLogin != base.AuthRequireGuestsToLogin {
+        base.AuthRequireGuestsToLogin = overlay.AuthRequireGuestsToLogin
+    }
+    if overlay.AuthLocalUsers.Enabled {
+        base.AuthLocalUsers.Enabled = overlay.AuthLocalUsers.Enabled
+    }
+}
 
-	// Override AuthLocalUsers.Enabled if set
-	if overlay.AuthLocalUsers.Enabled {
-		base.AuthLocalUsers.Enabled = overlay.AuthLocalUsers.Enabled
-	}
+func overrideNested(base *Config, overlay *Config) {
+    if overlay.DefaultPolicy.ShowDiagnostics != base.DefaultPolicy.ShowDiagnostics {
+        base.DefaultPolicy.ShowDiagnostics = overlay.DefaultPolicy.ShowDiagnostics
+    }
+    if overlay.DefaultPolicy.ShowLogList != base.DefaultPolicy.ShowLogList {
+        base.DefaultPolicy.ShowLogList = overlay.DefaultPolicy.ShowLogList
+    }
+    if overlay.Prometheus.Enabled != base.Prometheus.Enabled {
+        base.Prometheus.Enabled = overlay.Prometheus.Enabled
+    }
+    if overlay.Prometheus.DefaultGoMetrics != base.Prometheus.DefaultGoMetrics {
+        base.Prometheus.DefaultGoMetrics = overlay.Prometheus.DefaultGoMetrics
+    }
+}
 
-	// Override string fields if non-empty
-	overrideString(&base.BannerMessage, overlay.BannerMessage)
-	overrideString(&base.BannerCSS, overlay.BannerCSS)
-	overrideString(&base.LogLevel, overlay.LogLevel)
-	overrideString(&base.PageTitle, overlay.PageTitle)
-	overrideString(&base.SectionNavigationStyle, overlay.SectionNavigationStyle)
-	overrideString(&base.DefaultPopupOnStart, overlay.DefaultPopupOnStart)
+func overrideStrings(base *Config, overlay *Config) {
+    overrideString(&base.BannerMessage, overlay.BannerMessage)
+    overrideString(&base.BannerCSS, overlay.BannerCSS)
+    overrideString(&base.LogLevel, overlay.LogLevel)
+    overrideString(&base.PageTitle, overlay.PageTitle)
+    overrideString(&base.SectionNavigationStyle, overlay.SectionNavigationStyle)
+    overrideString(&base.DefaultPopupOnStart, overlay.DefaultPopupOnStart)
 }
 
 func overrideString(base *string, overlay string) {

+ 55 - 47
service/internal/config/config_reloader_test.go

@@ -90,55 +90,63 @@ var envConfigTests = []struct {
 }
 
 func TestEnvInConfig(t *testing.T) {
-	for _, tt := range envConfigTests {
-		cfg := DefaultConfig()
-
-		if tt.input != "" {
-			os.Setenv("INPUT", tt.input)
-		}
-
-		// Process the YAML content to replace environment variables
-		processedYaml := envRegex.ReplaceAllStringFunc(tt.yaml, func(match string) string {
-			submatches := envRegex.FindStringSubmatch(match)
-			key := submatches[1]
-			val, _ := os.LookupEnv(key)
-			return val
-		})
-
-		k := koanf.New(".")
-		err := k.Load(rawbytes.Provider([]byte(processedYaml)), yaml.Parser())
-		if err != nil {
-			t.Errorf("Error loading YAML: %v", err)
-			continue
-		}
+    for _, tt := range envConfigTests {
+        cfg := DefaultConfig()
+        setIfNotEmpty("INPUT", tt.input)
+        processed := processYamlWithEnv(tt.yaml)
+        k, err := loadKoanf(processed)
+        if err != nil {
+            t.Errorf("Error loading YAML: %v", err)
+            continue
+        }
+        if err := k.Unmarshal(".", cfg); err != nil {
+            t.Errorf("Error unmarshalling config: %v", err)
+            continue
+        }
+        manualAssigns(k, cfg)
+        field := tt.selector(cfg)
+        assert.Equal(t, tt.output, field, "Unmarshaled config field doesn't match expected value: env=\"%s\"", tt.input)
+        os.Unsetenv("INPUT")
+    }
+}
 
-		// Try default unmarshaling
-		err = k.Unmarshal(".", cfg)
-		if err != nil {
-			t.Errorf("Error unmarshalling config: %v", err)
-			continue
-		}
+func setIfNotEmpty(key, val string) {
+    if val != "" {
+        os.Setenv(key, val)
+    }
+}
 
-		// Manual field assignment for testing (since default unmarshaling has issues with field mapping)
-		if k.Exists("PageTitle") {
-			cfg.PageTitle = k.String("PageTitle")
-		}
-		if k.Exists("CheckForUpdates") {
-			cfg.CheckForUpdates = k.Bool("CheckForUpdates")
-		}
-		if k.Exists("LogHistoryPageSize") {
-			cfg.LogHistoryPageSize = k.Int64("LogHistoryPageSize")
-		}
-		if k.Exists("actions") {
-			var actions []*Action
-			if err := k.Unmarshal("actions", &actions); err == nil {
-				cfg.Actions = actions
-			}
-		}
+func processYamlWithEnv(content string) string {
+    return envRegex.ReplaceAllStringFunc(content, func(match string) string {
+        submatches := envRegex.FindStringSubmatch(match)
+        key := submatches[1]
+        val, _ := os.LookupEnv(key)
+        return val
+    })
+}
 
-		field := tt.selector(cfg)
-		assert.Equal(t, tt.output, field, "Unmarshaled config field doesn't match expected value: env=\"%s\"", tt.input)
+func loadKoanf(processed string) (*koanf.Koanf, error) {
+    k := koanf.New(".")
+    if err := k.Load(rawbytes.Provider([]byte(processed)), yaml.Parser()); err != nil {
+        return nil, err
+    }
+    return k, nil
+}
 
-		os.Unsetenv("INPUT")
-	}
+func manualAssigns(k *koanf.Koanf, cfg *Config) {
+    if k.Exists("PageTitle") {
+        cfg.PageTitle = k.String("PageTitle")
+    }
+    if k.Exists("CheckForUpdates") {
+        cfg.CheckForUpdates = k.Bool("CheckForUpdates")
+    }
+    if k.Exists("LogHistoryPageSize") {
+        cfg.LogHistoryPageSize = k.Int64("LogHistoryPageSize")
+    }
+    if k.Exists("actions") {
+        var actions []*Action
+        if err := k.Unmarshal("actions", &actions); err == nil {
+            cfg.Actions = actions
+        }
+    }
 }

+ 25 - 28
service/internal/entities/entities.go

@@ -30,36 +30,33 @@ func AddListener(l func()) {
 }
 
 func SetupEntityFileWatchers(cfg *config.Config) {
-	configDir := cfg.GetDir()
-
-	// Only use var directory if not in integration test mode
-	absConfigDir, _ := filepath.Abs(configDir)
-	if !strings.Contains(absConfigDir, "integration-tests") {
-		configDirVar := filepath.Join(configDir, "var") // for development purposes
-
-		if _, err := os.Stat(configDirVar); err == nil {
-			configDir = configDirVar
-		}
-	}
-
-	for entityIndex := range cfg.Entities { // #337 - iterate by key, not by value
-		ef := cfg.Entities[entityIndex]
-		p := ef.File
-
-		if !filepath.IsAbs(p) {
-			p = filepath.Join(configDir, p)
-
-			log.WithFields(log.Fields{
-				"entityFile": p,
-			}).Debugf("Adding config dir to entity file path")
-		}
+    baseDir := resolveEntitiesBaseDir(cfg.GetDir())
+    for i := range cfg.Entities { // #337 - iterate by key, not by value
+        ef := cfg.Entities[i]
+        watchAndLoadEntity(baseDir, ef)
+    }
+}
 
-		go filehelper.WatchFileWrite(p, func(filename string) {
-			loadEntityFile(p, ef.Name)
-		})
+func resolveEntitiesBaseDir(configDir string) string {
+    absConfigDir, _ := filepath.Abs(configDir)
+    if strings.Contains(absConfigDir, "integration-tests") {
+        return configDir
+    }
+    devVar := filepath.Join(configDir, "var")
+    if _, err := os.Stat(devVar); err == nil {
+        return devVar
+    }
+    return configDir
+}
 
-		loadEntityFile(p, ef.Name)
-	}
+func watchAndLoadEntity(baseDir string, ef *config.EntityFile) {
+    p := ef.File
+    if !filepath.IsAbs(p) {
+        p = filepath.Join(baseDir, p)
+        log.WithFields(log.Fields{"entityFile": p}).Debugf("Adding config dir to entity file path")
+    }
+    go filehelper.WatchFileWrite(p, func(filename string) { loadEntityFile(p, ef.Name) })
+    loadEntityFile(p, ef.Name)
 }
 
 func loadEntityFile(filename string, entityname string) {

+ 39 - 52
service/internal/executor/arguments.go

@@ -43,45 +43,37 @@ func parseCommandForReplacements(shellCommand string, values map[string]string,
 }
 
 func parseActionExec(values map[string]string, action *config.Action, entity *entities.Entity) ([]string, error) {
-	if action == nil {
-		return nil, fmt.Errorf("action is nil")
-	}
-
-	for _, arg := range action.Arguments {
-		argName := arg.Name
-		argValue := values[argName]
-
-		err := typecheckActionArgument(&arg, argValue, action)
-
-		if err != nil {
-			return nil, err
-		}
-
-		log.WithFields(log.Fields{
-			"name":  argName,
-			"value": argValue,
-		}).Debugf("Arg assigned")
-	}
-
-	parsedArgs := make([]string, len(action.Exec))
-	for i, arg := range action.Exec {
-		parsedArg, err := parseCommandForReplacements(arg, values, entity)
-		if err != nil {
-			return nil, err
-		}
-
-		parsedArg = entities.ParseTemplateWithArgs(parsedArg, entity, values)
-		parsedArgs[i] = parsedArg
-	}
-
-	redactedArgs := redactExecArgs(parsedArgs, action.Arguments, values)
+    if action == nil {
+        return nil, fmt.Errorf("action is nil")
+    }
+    if err := validateArguments(values, action); err != nil {
+        return nil, err
+    }
+    parsed := make([]string, len(action.Exec))
+    for i, a := range action.Exec {
+        arg, err := parseCommandForReplacements(a, values, entity)
+        if err != nil {
+            return nil, err
+        }
+        parsed[i] = entities.ParseTemplateWithArgs(arg, entity, values)
+    }
+    logParsedExec(action, parsed, values)
+    return parsed, nil
+}
 
-	log.WithFields(log.Fields{
-		"actionTitle": action.Title,
-		"cmd":         redactedArgs,
-	}).Infof("Action parse args - After (Exec)")
+func validateArguments(values map[string]string, action *config.Action) error {
+    for _, arg := range action.Arguments {
+        if err := typecheckActionArgument(&arg, values[arg.Name], action); err != nil {
+            return err
+        }
+        log.WithFields(log.Fields{"name": arg.Name, "value": values[arg.Name]}).Debugf("Arg assigned")
+    }
+    return nil
+}
 
-	return parsedArgs, nil
+func logParsedExec(action *config.Action, parsed []string, values map[string]string) {
+    redacted := redactExecArgs(parsed, action.Arguments, values)
+    log.WithFields(log.Fields{"actionTitle": action.Title, "cmd": redacted}).Infof("Action parse args - After (Exec)")
 }
 
 func parseActionArguments(values map[string]string, action *config.Action, entity *entities.Entity) (string, error) {
@@ -295,21 +287,16 @@ func typeSafetyCheckUrl(value string) error {
 }
 
 func checkShellArgumentSafety(action *config.Action) error {
-	if action.Shell == "" {
-		return nil
-	}
-
-	unsafeTypes := []string{"url", "email", "raw_string_multiline", "very_dangerous_raw_string"}
-
-	for _, arg := range action.Arguments {
-		for _, unsafeType := range unsafeTypes {
-			if arg.Type == unsafeType {
-				return fmt.Errorf("unsafe argument type '%s' cannot be used with Shell execution. Use 'exec' instead. See https://docs.olivetin.app/action_execution/shellvsexec.html", arg.Type)
-			}
-		}
-	}
-
-	return nil
+    if action.Shell == "" {
+        return nil
+    }
+    unsafe := map[string]struct{}{"url": {}, "email": {}, "raw_string_multiline": {}, "very_dangerous_raw_string": {}}
+    for _, arg := range action.Arguments {
+        if _, bad := unsafe[arg.Type]; bad {
+            return fmt.Errorf("unsafe argument type '%s' cannot be used with Shell execution. Use 'exec' instead. See https://docs.olivetin.app/action_execution/shellvsexec.html", arg.Type)
+        }
+    }
+    return nil
 }
 
 func mangleInvalidArgumentValues(req *ExecutionRequest) {

+ 57 - 49
service/internal/executor/executor.go

@@ -427,51 +427,62 @@ func stepACLCheck(req *ExecutionRequest) bool {
 }
 
 func stepParseArgs(req *ExecutionRequest) bool {
-	var err error
+	ensureArgumentMap(req)
+	injectSystemArgs(req)
+	mangleInvalidArgumentValues(req)
+	if !hasBindingAndAction(req) {
+		return fail(req, fmt.Errorf("cannot parse arguments: Binding or Action is nil"))
+	}
+	if hasExec(req) {
+		return parseExec(req)
+	}
+	if err := checkShellArgumentSafety(req.Binding.Action); err != nil {
+		return fail(req, err)
+	}
+	cmd, err := parseActionArguments(req.Arguments, req.Binding.Action, req.Binding.Entity)
+	if err != nil {
+		return fail(req, err)
+	}
+	req.useDirectExec = false
+	req.finalParsedCommand = cmd
+	return true
+}
 
+func ensureArgumentMap(req *ExecutionRequest) {
 	if req.Arguments == nil {
 		req.Arguments = make(map[string]string)
 	}
+}
 
+func injectSystemArgs(req *ExecutionRequest) {
 	req.Arguments["ot_executionTrackingId"] = req.TrackingID
 	req.Arguments["ot_username"] = req.AuthenticatedUser.Username
+}
 
-	mangleInvalidArgumentValues(req)
-
-	if req.Binding == nil || req.Binding.Action == nil {
-		err = fmt.Errorf("cannot parse arguments: Binding or Action is nil")
-		req.logEntry.Output = err.Error()
-		log.Warn(err.Error())
-		return false
-	}
-
-	if len(req.Binding.Action.Exec) > 0 {
-		req.useDirectExec = true
-		req.execArgs, err = parseActionExec(req.Arguments, req.Binding.Action, req.Binding.Entity)
-	} else {
-		req.useDirectExec = false
-
-		err = checkShellArgumentSafety(req.Binding.Action)
-		if err != nil {
-			req.logEntry.Output = err.Error()
-			log.Warn(err.Error())
-			return false
-		}
+func hasBindingAndAction(req *ExecutionRequest) bool {
+	return !(req.Binding == nil || req.Binding.Action == nil)
+}
 
-		req.finalParsedCommand, err = parseActionArguments(req.Arguments, req.Binding.Action, req.Binding.Entity)
-	}
+func hasExec(req *ExecutionRequest) bool {
+	return len(req.Binding.Action.Exec) > 0
+}
 
+func parseExec(req *ExecutionRequest) bool {
+	req.useDirectExec = true
+	args, err := parseActionExec(req.Arguments, req.Binding.Action, req.Binding.Entity)
 	if err != nil {
-		req.logEntry.Output = err.Error()
-
-		log.Warn(err.Error())
-
-		return false
+		return fail(req, err)
 	}
-
+	req.execArgs = args
 	return true
 }
 
+func fail(req *ExecutionRequest, err error) bool {
+	req.logEntry.Output = err.Error()
+	log.Warn(err.Error())
+	return false
+}
+
 func stepRequestAction(req *ExecutionRequest) bool {
 	metricActionsRequested.Inc()
 
@@ -586,34 +597,17 @@ func buildEnv(args map[string]string) []string {
 func stepExec(req *ExecutionRequest) bool {
 	ctx, cancel := context.WithTimeout(context.Background(), time.Duration(req.Binding.Action.Timeout)*time.Second)
 	defer cancel()
-
 	streamer := &OutputStreamer{Req: req}
-
-	var cmd *exec.Cmd
-	if req.useDirectExec {
-		cmd = wrapCommandDirect(ctx, req.execArgs)
-	} else {
-		cmd = wrapCommandInShell(ctx, req.finalParsedCommand)
-	}
-
+	cmd := buildCommand(ctx, req)
 	if cmd == nil {
 		req.logEntry.Output = "Cannot execute: no command arguments provided"
 		log.Warn("Cannot execute: no command arguments provided")
 		return false
 	}
-
-	cmd.Stdout = streamer
-	cmd.Stderr = streamer
-	cmd.Env = buildEnv(req.Arguments)
-
-	req.logEntry.ExecutionStarted = true
-
+	prepareCommand(cmd, streamer, req)
 	runerr := cmd.Start()
-
 	req.logEntry.Process = cmd.Process
-
 	waiterr := cmd.Wait()
-
 	req.logEntry.ExitCode = int32(cmd.ProcessState.ExitCode())
 	req.logEntry.Output = streamer.String()
 
@@ -643,6 +637,20 @@ func stepExec(req *ExecutionRequest) bool {
 	return true
 }
 
+func buildCommand(ctx context.Context, req *ExecutionRequest) *exec.Cmd {
+	if req.useDirectExec {
+		return wrapCommandDirect(ctx, req.execArgs)
+	}
+	return wrapCommandInShell(ctx, req.finalParsedCommand)
+}
+
+func prepareCommand(cmd *exec.Cmd, streamer *OutputStreamer, req *ExecutionRequest) {
+	cmd.Stdout = streamer
+	cmd.Stderr = streamer
+	cmd.Env = buildEnv(req.Arguments)
+	req.logEntry.ExecutionStarted = true
+}
+
 func stepExecAfter(req *ExecutionRequest) bool {
 	if req.Binding.Action.ShellAfterCompleted == "" {
 		return true