Преглед на файлове

bugfix: Race condition on executor logs (#385)

James Read преди 1 година
родител
ревизия
40158eda71
променени са 2 файла, в които са добавени 74 реда и са изтрити 15 реда
  1. 62 6
      internal/executor/executor.go
  2. 12 9
      internal/grpcapi/grpcApi.go

+ 62 - 6
internal/executor/executor.go

@@ -37,9 +37,11 @@ type ActionBinding struct {
 // Executor represents a helper class for executing commands. It's main method
 // is ExecRequest
 type Executor struct {
-	Logs           map[string]*InternalLogEntry
+	logs           map[string]*InternalLogEntry
 	LogsByActionId map[string][]*InternalLogEntry
 
+	logmutex sync.RWMutex
+
 	MapActionIdToBinding     map[string]*ActionBinding
 	MapActionIdToBindingLock sync.RWMutex
 
@@ -100,7 +102,7 @@ type executorStepFunc func(*ExecutionRequest) bool
 func DefaultExecutor(cfg *config.Config) *Executor {
 	e := Executor{}
 	e.Cfg = cfg
-	e.Logs = make(map[string]*InternalLogEntry)
+	e.logs = make(map[string]*InternalLogEntry)
 	e.LogsByActionId = make(map[string][]*InternalLogEntry)
 	e.MapActionIdToBinding = make(map[string]*ActionBinding)
 
@@ -132,6 +134,52 @@ func (e *Executor) AddListener(m listener) {
 	e.listeners = append(e.listeners, m)
 }
 
+func (e *Executor) GetLogsCopy() map[string]*InternalLogEntry {
+	e.logmutex.RLock()
+
+	copy := make(map[string]*InternalLogEntry)
+
+	for k, v := range e.logs {
+		copy[k] = v
+	}
+
+	e.logmutex.RUnlock()
+
+	return copy
+}
+
+func (e *Executor) GetLog(trackingID string) (*InternalLogEntry, bool) {
+	e.logmutex.RLock()
+
+	entry, found := e.logs[trackingID]
+
+	e.logmutex.RUnlock()
+
+	return entry, found
+}
+
+func (e *Executor) GetLogsByActionId(actionId string) []*InternalLogEntry {
+	e.logmutex.RLock()
+
+	logs, found := e.LogsByActionId[actionId]
+
+	e.logmutex.RUnlock()
+
+	if !found {
+		return make([]*InternalLogEntry, 0)
+	}
+
+	return logs
+}
+
+func (e *Executor) SetLog(trackingID string, entry *InternalLogEntry) {
+	e.logmutex.Lock()
+
+	e.logs[trackingID] = entry
+
+	e.logmutex.Unlock()
+}
+
 // ExecRequest processes an ExecutionRequest
 func (e *Executor) ExecRequest(req *ExecutionRequest) (*sync.WaitGroup, string) {
 	req.executor = e
@@ -147,7 +195,7 @@ func (e *Executor) ExecRequest(req *ExecutionRequest) (*sync.WaitGroup, string)
 		ActionIcon:          "💩",
 	}
 
-	_, isDuplicate := e.Logs[req.TrackingID]
+	_, isDuplicate := e.GetLog(req.TrackingID)
 
 	if isDuplicate || req.TrackingID == "" {
 		req.TrackingID = uuid.NewString()
@@ -155,7 +203,7 @@ func (e *Executor) ExecRequest(req *ExecutionRequest) (*sync.WaitGroup, string)
 
 	log.Tracef("executor.ExecRequest(): %v", req)
 
-	e.Logs[req.TrackingID] = req.logEntry
+	e.SetLog(req.TrackingID, req.logEntry)
 
 	wg := new(sync.WaitGroup)
 	wg.Add(1)
@@ -185,12 +233,16 @@ func (e *Executor) execChain(req *ExecutionRequest) {
 func getConcurrentCount(req *ExecutionRequest) int {
 	concurrentCount := 0
 
-	for _, log := range req.executor.LogsByActionId[req.Action.ID] {
+	req.executor.logmutex.RLock()
+
+	for _, log := range req.executor.GetLogsByActionId(req.Action.ID) {
 		if !log.ExecutionFinished {
 			concurrentCount += 1
 		}
 	}
 
+	req.executor.logmutex.RUnlock()
+
 	return concurrentCount
 }
 
@@ -232,7 +284,7 @@ func getExecutionsCount(rate config.RateSpec, req *ExecutionRequest) int {
 
 	then := time.Now().Add(-duration)
 
-	for _, logEntry := range req.executor.LogsByActionId[req.Action.ID] {
+	for _, logEntry := range req.executor.GetLogsByActionId(req.Action.ID) {
 		if logEntry.DatetimeStarted.After(then) && !logEntry.Blocked {
 
 			executions += 1
@@ -308,12 +360,16 @@ func stepRequestAction(req *ExecutionRequest) bool {
 	req.logEntry.ActionIcon = req.Action.Icon
 	req.logEntry.ActionId = req.Action.ID
 
+	req.executor.logmutex.Lock()
+
 	if _, containsKey := req.executor.LogsByActionId[req.Action.ID]; !containsKey {
 		req.executor.LogsByActionId[req.Action.ID] = make([]*InternalLogEntry, 0)
 	}
 
 	req.executor.LogsByActionId[req.Action.ID] = append(req.executor.LogsByActionId[req.Action.ID], req.logEntry)
 
+	req.executor.logmutex.Unlock()
+
 	log.WithFields(log.Fields{
 		"actionTitle": req.logEntry.ActionTitle,
 		"tags":        req.Tags,

+ 12 - 9
internal/grpcapi/grpcApi.go

@@ -35,7 +35,7 @@ func (api *oliveTinAPI) KillAction(ctx ctx.Context, req *pb.KillActionRequest) (
 		ExecutionTrackingId: req.ExecutionTrackingId,
 	}
 
-	execReqLogEntry, found := api.executor.Logs[req.ExecutionTrackingId]
+	execReqLogEntry, found := api.executor.GetLog(req.ExecutionTrackingId)
 
 	ret.Found = found
 
@@ -99,7 +99,7 @@ func (api *oliveTinAPI) StartActionAndWait(ctx ctx.Context, req *pb.StartActionA
 	wg, _ := api.executor.ExecRequest(&execReq)
 	wg.Wait()
 
-	internalLogEntry, ok := api.executor.Logs[execReq.TrackingID]
+	internalLogEntry, ok := api.executor.GetLog(execReq.TrackingID)
 
 	if ok {
 		return &pb.StartActionAndWaitResponse{
@@ -142,7 +142,7 @@ func (api *oliveTinAPI) StartActionByGetAndWait(ctx ctx.Context, req *pb.StartAc
 	wg, _ := api.executor.ExecRequest(&execReq)
 	wg.Wait()
 
-	internalLogEntry, ok := api.executor.Logs[execReq.TrackingID]
+	internalLogEntry, ok := api.executor.GetLog(execReq.TrackingID)
 
 	if ok {
 		return &pb.StartActionByGetAndWaitResponse{
@@ -172,7 +172,7 @@ func internalLogEntryToPb(logEntry *executor.InternalLogEntry) *pb.LogEntry {
 }
 
 func getExecutionStatusByTrackingID(api *oliveTinAPI, executionTrackingId string) *executor.InternalLogEntry {
-	logEntry, ok := api.executor.Logs[executionTrackingId]
+	logEntry, ok := api.executor.GetLog(executionTrackingId)
 
 	if !ok {
 		return nil
@@ -184,10 +184,13 @@ func getExecutionStatusByTrackingID(api *oliveTinAPI, executionTrackingId string
 func getMostRecentExecutionStatusById(api *oliveTinAPI, actionId string) *executor.InternalLogEntry {
 	var ile *executor.InternalLogEntry
 
-	for _, candidateLe := range api.executor.Logs {
-		if actionId == candidateLe.ActionId {
-			ile = candidateLe
-		}
+	logs := api.executor.GetLogsByActionId(actionId)
+
+	if len(logs) == 0 {
+		return nil
+	} else {
+		// Get last log entry
+		ile = logs[len(logs)-1]
 	}
 
 	return ile
@@ -262,7 +265,7 @@ func (api *oliveTinAPI) GetLogs(ctx ctx.Context, req *pb.GetLogsRequest) (*pb.Ge
 
 	// TODO Limit to 10 entries or something to prevent browser lag.
 
-	for trackingId, logEntry := range api.executor.Logs {
+	for trackingId, logEntry := range api.executor.GetLogsCopy() {
 		action := cfg.FindAction(logEntry.ActionTitle)
 
 		if action == nil || acl.IsAllowedLogs(cfg, user, action) {