Kaynağa Gözat

Merge branch 'next' of github.com:OliveTin/OliveTin into next

jamesread 4 ay önce
ebeveyn
işleme
131393fb2d

+ 89 - 43
service/internal/api/api.go

@@ -70,20 +70,21 @@ func (api *oliveTinAPI) KillAction(ctx ctx.Context, req *connect.Request[apiv1.K
 	execReqLogEntry, ret.Found = api.executor.GetLog(req.Msg.ExecutionTrackingId)
 
 	if !ret.Found {
-		log.Warnf("Killing execution request not possible - not found by tracking ID: %v", req.Msg.ExecutionTrackingId)
-		return connect.NewResponse(ret), nil
+		return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("execution not found for tracking ID %s", req.Msg.ExecutionTrackingId))
 	}
 
-	log.Warnf("Killing execution request by tracking ID: %v", req.Msg.ExecutionTrackingId)
+	if execReqLogEntry.Binding == nil {
+		return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("log entry has no binding for tracking ID %s", req.Msg.ExecutionTrackingId))
+	}
 
 	action := execReqLogEntry.Binding.Action
 
 	if action == nil {
-		log.Warnf("Killing execution request not possible - action not found: %v", execReqLogEntry.ActionTitle)
-		ret.Killed = false
-		return connect.NewResponse(ret), nil
+		return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("action not found for tracking ID %s", req.Msg.ExecutionTrackingId))
 	}
 
+	log.Warnf("Killing execution request by tracking ID: %v", req.Msg.ExecutionTrackingId)
+
 	user := auth.UserFromApiCall(ctx, req, api.cfg)
 
 	api.killActionByTrackingId(user, action, execReqLogEntry, ret)
@@ -205,42 +206,58 @@ func (api *oliveTinAPI) LocalUserLogin(ctx ctx.Context, req *connect.Request[api
 	return response, nil
 }
 
-func (api *oliveTinAPI) StartActionAndWait(ctx ctx.Context, req *connect.Request[apiv1.StartActionAndWaitRequest]) (*connect.Response[apiv1.StartActionAndWaitResponse], error) {
-	args := make(map[string]string)
-
-	for _, arg := range req.Msg.Arguments {
-		args[arg.Name] = arg.Value
-	}
-
-	user := auth.UserFromApiCall(ctx, req, api.cfg)
-
+func (api *oliveTinAPI) startActionAndWaitRun(binding *executor.ActionBinding, args map[string]string, user *authpublic.AuthenticatedUser) (*executor.InternalLogEntry, bool) {
 	execReq := executor.ExecutionRequest{
-		Binding:           api.executor.FindBindingByID(req.Msg.ActionId),
+		Binding:           binding,
 		TrackingID:        uuid.NewString(),
 		Arguments:         args,
 		AuthenticatedUser: user,
 		Cfg:               api.cfg,
 	}
-
 	wg, _ := api.executor.ExecRequest(&execReq)
 	wg.Wait()
+	return api.executor.GetLog(execReq.TrackingID)
+}
 
-	internalLogEntry, ok := api.executor.GetLog(execReq.TrackingID)
+func (api *oliveTinAPI) findBindingOrNotFound(actionId string) (*executor.ActionBinding, error) {
+	binding := api.executor.FindBindingByID(actionId)
+	if binding == nil || binding.Action == nil {
+		return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("action with ID %s not found", actionId))
+	}
+	return binding, nil
+}
 
-	if ok {
-		return connect.NewResponse(&apiv1.StartActionAndWaitResponse{
-			LogEntry: api.internalLogEntryToPb(internalLogEntry, user),
-		}), nil
-	} else {
-		return nil, fmt.Errorf("execution not found")
+func (api *oliveTinAPI) StartActionAndWait(ctx ctx.Context, req *connect.Request[apiv1.StartActionAndWaitRequest]) (*connect.Response[apiv1.StartActionAndWaitResponse], error) {
+	binding, err := api.findBindingOrNotFound(req.Msg.ActionId)
+	if err != nil {
+		return nil, err
+	}
+
+	args := make(map[string]string)
+	for _, arg := range req.Msg.Arguments {
+		args[arg.Name] = arg.Value
+	}
+	user := auth.UserFromApiCall(ctx, req, api.cfg)
+
+	internalLogEntry, ok := api.startActionAndWaitRun(binding, args, user)
+	if !ok {
+		return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("execution not found"))
 	}
+	return connect.NewResponse(&apiv1.StartActionAndWaitResponse{
+		LogEntry: api.internalLogEntryToPb(internalLogEntry, user),
+	}), nil
 }
 
 func (api *oliveTinAPI) StartActionByGet(ctx ctx.Context, req *connect.Request[apiv1.StartActionByGetRequest]) (*connect.Response[apiv1.StartActionByGetResponse], error) {
+	binding := api.executor.FindBindingByID(req.Msg.ActionId)
+	if binding == nil || binding.Action == nil {
+		return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("action with ID %s not found", req.Msg.ActionId))
+	}
+
 	args := make(map[string]string)
 
 	execReq := executor.ExecutionRequest{
-		Binding:           api.executor.FindBindingByID(req.Msg.ActionId),
+		Binding:           binding,
 		TrackingID:        uuid.NewString(),
 		Arguments:         args,
 		AuthenticatedUser: auth.UserFromApiCall(ctx, req, api.cfg),
@@ -255,12 +272,17 @@ func (api *oliveTinAPI) StartActionByGet(ctx ctx.Context, req *connect.Request[a
 }
 
 func (api *oliveTinAPI) StartActionByGetAndWait(ctx ctx.Context, req *connect.Request[apiv1.StartActionByGetAndWaitRequest]) (*connect.Response[apiv1.StartActionByGetAndWaitResponse], error) {
+	binding := api.executor.FindBindingByID(req.Msg.ActionId)
+	if binding == nil || binding.Action == nil {
+		return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("action with ID %s not found", req.Msg.ActionId))
+	}
+
 	args := make(map[string]string)
 
 	user := auth.UserFromApiCall(ctx, req, api.cfg)
 
 	execReq := executor.ExecutionRequest{
-		Binding:           api.executor.FindBindingByID(req.Msg.ActionId),
+		Binding:           binding,
 		TrackingID:        uuid.NewString(),
 		Arguments:         args,
 		AuthenticatedUser: user,
@@ -276,9 +298,8 @@ func (api *oliveTinAPI) StartActionByGetAndWait(ctx ctx.Context, req *connect.Re
 		return connect.NewResponse(&apiv1.StartActionByGetAndWaitResponse{
 			LogEntry: api.internalLogEntryToPb(internalLogEntry, user),
 		}), nil
-	} else {
-		return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("execution not found"))
 	}
+	return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("execution not found"))
 }
 
 func calculateRateLimitExpires(api *oliveTinAPI, logEntry *executor.InternalLogEntry) string {
@@ -392,6 +413,8 @@ func (api *oliveTinAPI) ExecutionStatus(ctx ctx.Context, req *connect.Request[ap
 func (api *oliveTinAPI) Logout(ctx ctx.Context, req *connect.Request[apiv1.LogoutRequest]) (*connect.Response[apiv1.LogoutResponse], error) {
 	user := auth.UserFromApiCall(ctx, req, api.cfg)
 
+	auth.RevokeSessionForProvider(api.cfg, user.Provider, user.SID)
+
 	log.WithFields(log.Fields{
 		"username": user.Username,
 		"provider": user.Provider,
@@ -436,7 +459,7 @@ func (api *oliveTinAPI) GetActionBinding(ctx ctx.Context, req *connect.Request[a
 
 	binding := api.executor.FindBindingByID(req.Msg.BindingId)
 
-	if binding == nil {
+	if binding == nil || binding.Action == nil {
 		return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("action with ID %s not found", req.Msg.BindingId))
 	}
 
@@ -646,7 +669,16 @@ error messages more quickly before starting the action.
 It uses the same validation logic as the executor, including mangling argument
 values (e.g., datetime formatting, checkbox title-to-value conversion).
 */
+func (api *oliveTinAPI) argumentNotFoundForValidation(msg *apiv1.ValidateArgumentTypeRequest) bool {
+	arg, _ := api.findArgumentForValidation(msg.BindingId, msg.ArgumentName)
+	return arg == nil && (msg.BindingId != "" || msg.ArgumentName != "")
+}
+
 func (api *oliveTinAPI) ValidateArgumentType(ctx ctx.Context, req *connect.Request[apiv1.ValidateArgumentTypeRequest]) (*connect.Response[apiv1.ValidateArgumentTypeResponse], error) {
+	if api.argumentNotFoundForValidation(req.Msg) {
+		return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("action or argument not found for binding ID %s", req.Msg.BindingId))
+	}
+
 	err := api.validateArgumentTypeInternal(req.Msg)
 	desc := ""
 	if err != nil {
@@ -747,6 +779,13 @@ func (api *oliveTinAPI) DumpVars(ctx ctx.Context, req *connect.Request[apiv1.Dum
 	return connect.NewResponse(res), nil
 }
 
+func debugBindingActionTitle(binding *executor.ActionBinding) string {
+	if binding == nil || binding.Action == nil {
+		return ""
+	}
+	return binding.Action.Title
+}
+
 func (api *oliveTinAPI) DumpPublicIdActionMap(ctx ctx.Context, req *connect.Request[apiv1.DumpPublicIdActionMapRequest]) (*connect.Response[apiv1.DumpPublicIdActionMapResponse], error) {
 	res := &apiv1.DumpPublicIdActionMapResponse{}
 	res.Contents = make(map[string]*apiv1.DebugBinding)
@@ -761,7 +800,7 @@ func (api *oliveTinAPI) DumpPublicIdActionMap(ctx ctx.Context, req *connect.Requ
 
 	for k, v := range api.executor.MapActionBindings {
 		res.Contents[k] = &apiv1.DebugBinding{
-			ActionTitle: v.Action.Title,
+			ActionTitle: debugBindingActionTitle(v),
 		}
 	}
 
@@ -1271,30 +1310,37 @@ func (api *oliveTinAPI) RestartAction(ctx ctx.Context, req *connect.Request[apiv
 		ExecutionTrackingId: req.Msg.ExecutionTrackingId,
 	}
 
-	var execReqLogEntry *executor.InternalLogEntry
-
 	execReqLogEntry, found := api.executor.GetLog(req.Msg.ExecutionTrackingId)
 
 	if !found {
-		log.Warnf("Restarting execution request not possible - not found by tracking ID: %v", req.Msg.ExecutionTrackingId)
-		return connect.NewResponse(ret), nil
+		return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("execution not found for tracking ID %s", req.Msg.ExecutionTrackingId))
 	}
 
-	log.Warnf("Restarting execution request by tracking ID: %v", req.Msg.ExecutionTrackingId)
+	if execReqLogEntry.Binding == nil {
+		return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("log entry has no binding for tracking ID %s", req.Msg.ExecutionTrackingId))
+	}
 
 	action := execReqLogEntry.Binding.Action
 
 	if action == nil {
-		log.Warnf("Restarting execution request not possible - action not found: %v", execReqLogEntry.ActionTitle)
-		return connect.NewResponse(ret), nil
+		return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("action not found for tracking ID %s", req.Msg.ExecutionTrackingId))
 	}
 
-	return api.StartAction(ctx, &connect.Request[apiv1.StartActionRequest]{
-		Msg: &apiv1.StartActionRequest{
-			BindingId:        execReqLogEntry.GetBindingId(),
-			UniqueTrackingId: req.Msg.ExecutionTrackingId,
-		},
-	})
+	authenticatedUser := auth.UserFromApiCall(ctx, req, api.cfg)
+
+	// TrackingID is deliberately not passed to the executor, so that it generates a new one for the restarted execution.
+	// This is because the old execution (identified by the old TrackingID) is already used.
+	execReq := executor.ExecutionRequest{
+		Binding:           execReqLogEntry.Binding,
+		Arguments:         make(map[string]string),
+		AuthenticatedUser: authenticatedUser,
+		Cfg:               api.cfg,
+	}
+
+	api.executor.ExecRequest(&execReq)
+
+	ret.ExecutionTrackingId = execReq.TrackingID
+	return connect.NewResponse(ret), nil
 }
 
 func newServer(ex *executor.Executor) *oliveTinAPI {

+ 31 - 19
service/internal/api/apiActions.go

@@ -28,18 +28,19 @@ func (rr *DashboardRenderRequest) findAction(title string) *apiv1.Action {
 	return rr.findActionForEntity(title, nil)
 }
 
+func bindingMatchesTitleAndEntity(binding *executor.ActionBinding, title string, entity *entities.Entity) bool {
+	return binding != nil && binding.Action != nil && binding.Action.Title == title && matchesEntity(binding, entity)
+}
+
 func (rr *DashboardRenderRequest) findActionForEntity(title string, entity *entities.Entity) *apiv1.Action {
 	rr.ex.MapActionBindingsLock.RLock()
 	defer rr.ex.MapActionBindingsLock.RUnlock()
 
 	for _, binding := range rr.ex.MapActionBindings {
-		if binding.Action.Title != title {
+		if !bindingMatchesTitleAndEntity(binding, title, entity) {
 			continue
 		}
-
-		if matchesEntity(binding, entity) {
-			return buildAction(binding, rr)
-		}
+		return buildAction(binding, rr)
 	}
 
 	return nil
@@ -117,26 +118,37 @@ func getDefaultArgumentValue(cfgArg config.ActionArgument, entity *entities.Enti
 	return defaultValue
 }
 
-func buildAction(actionBinding *executor.ActionBinding, rr *DashboardRenderRequest) *apiv1.Action {
-	action := actionBinding.Action
+func formatRateLimitExpiry(expiryUnix int64) string {
+	if expiryUnix <= 0 {
+		return ""
+	}
+	return time.Unix(expiryUnix, 0).Format("2006-01-02 15:04:05")
+}
 
-	aclCanExec := acl.IsAllowedExec(rr.cfg, rr.AuthenticatedUser, action)
-	enabledExprCanExec := evaluateEnabledExpression(action, actionBinding.Entity)
+func actionFromBinding(actionBinding *executor.ActionBinding) (*executor.ActionBinding, *config.Action) {
+	if actionBinding == nil || actionBinding.Action == nil {
+		return nil, nil
+	}
+	return actionBinding, actionBinding.Action
+}
 
-	// Calculate rate limit expiry time
-	expiryUnix := rr.ex.GetTimeUntilAvailable(actionBinding)
-	datetimeRateLimitExpires := ""
-	if expiryUnix > 0 {
-		datetimeRateLimitExpires = time.Unix(expiryUnix, 0).Format("2006-01-02 15:04:05")
+func buildAction(actionBinding *executor.ActionBinding, rr *DashboardRenderRequest) *apiv1.Action {
+	binding, action := actionFromBinding(actionBinding)
+	if binding == nil {
+		return nil
 	}
 
+	aclCanExec := acl.IsAllowedExec(rr.cfg, rr.AuthenticatedUser, action)
+	enabledExprCanExec := evaluateEnabledExpression(action, binding.Entity)
+	datetimeRateLimitExpires := formatRateLimitExpiry(rr.ex.GetTimeUntilAvailable(binding))
+
 	btn := apiv1.Action{
-		BindingId:                actionBinding.ID,
-		Title:                    tpl.ParseTemplateOfActionBeforeExec(action.Title, actionBinding.Entity),
-		Icon:                     tpl.ParseTemplateOfActionBeforeExec(action.Icon, actionBinding.Entity),
+		BindingId:                binding.ID,
+		Title:                    tpl.ParseTemplateOfActionBeforeExec(action.Title, binding.Entity),
+		Icon:                     tpl.ParseTemplateOfActionBeforeExec(action.Icon, binding.Entity),
 		CanExec:                  aclCanExec && enabledExprCanExec,
 		PopupOnStart:             action.PopupOnStart,
-		Order:                    int32(actionBinding.ConfigOrder),
+		Order:                    int32(binding.ConfigOrder),
 		Timeout:                  int32(action.Timeout),
 		DatetimeRateLimitExpires: datetimeRateLimitExpires,
 	}
@@ -147,7 +159,7 @@ func buildAction(actionBinding *executor.ActionBinding, rr *DashboardRenderReque
 			Title:                 cfgArg.Title,
 			Type:                  cfgArg.Type,
 			Description:           cfgArg.Description,
-			DefaultValue:          getDefaultArgumentValue(cfgArg, actionBinding.Entity),
+			DefaultValue:          getDefaultArgumentValue(cfgArg, binding.Entity),
 			Choices:               buildChoices(cfgArg),
 			Suggestions:           cfgArg.Suggestions,
 			SuggestionsBrowserKey: cfgArg.SuggestionsBrowserKey,

+ 4 - 1
service/internal/api/dashboards.go

@@ -130,7 +130,7 @@ func buildDefaultDashboard(rr *DashboardRenderRequest) *apiv1.Dashboard {
 	}
 
 	for _, binding := range rr.ex.MapActionBindings {
-		if binding.Action.Hidden {
+		if binding == nil || binding.Action == nil || binding.Action.Hidden {
 			continue
 		}
 
@@ -139,6 +139,9 @@ func buildDefaultDashboard(rr *DashboardRenderRequest) *apiv1.Dashboard {
 		}
 
 		action := buildAction(binding, rr)
+		if action == nil {
+			continue
+		}
 
 		fieldset.Contents = append(fieldset.Contents, &apiv1.DashboardComponent{
 			Type:   "link",

+ 6 - 0
service/internal/auth/otoauth2/restapi_auth_oauth2.go

@@ -391,6 +391,12 @@ func (h *OAuth2Handler) lookupOAuth2UserByState(state string) (*authTypes.Authen
 	return user, true
 }
 
+func (h *OAuth2Handler) RevokeSession(sid string) {
+	h.mu.Lock()
+	defer h.mu.Unlock()
+	delete(h.registeredStates, sid)
+}
+
 func (h *OAuth2Handler) CheckUserFromOAuth2Cookie(context *authTypes.AuthCheckingContext) *authTypes.AuthenticatedUser {
 	cookie, err := context.Request.Cookie("olivetin-sid-oauth")
 	if err != nil || cookie.Value == "" {

+ 35 - 2
service/internal/auth/sessions.go

@@ -25,8 +25,9 @@ type SessionStorage struct {
 }
 
 var (
-	sessionStorage      *SessionStorage
-	sessionStorageMutex sync.RWMutex
+	sessionStorage       *SessionStorage
+	sessionStorageMutex  sync.RWMutex
+	oauth2SessionRevoker func(sid string)
 )
 
 func init() {
@@ -58,6 +59,38 @@ func RegisterUserSession(cfg *config.Config, provider string, sid string, userna
 	saveUserSessions(cfg)
 }
 
+// RegisterOAuth2SessionRevoker registers a callback to revoke OAuth2 sessions on logout.
+// OAuth2 uses its own session storage; the API calls this when provider is oauth2.
+func RegisterOAuth2SessionRevoker(fn func(sid string)) {
+	oauth2SessionRevoker = fn
+}
+
+// RevokeSessionForProvider invalidates the session for the given provider and SID (e.g. on logout).
+// Local auth uses shared SessionStorage; OAuth2 uses a separate storage and revoker.
+func RevokeSessionForProvider(cfg *config.Config, provider string, sid string) {
+	if sid == "" {
+		return
+	}
+	if provider == "oauth2" && oauth2SessionRevoker != nil {
+		oauth2SessionRevoker(sid)
+		return
+	}
+	RevokeUserSession(cfg, provider, sid)
+}
+
+// RevokeUserSession removes a session from storage so it can no longer be used (e.g. on logout).
+func RevokeUserSession(cfg *config.Config, provider string, sid string) {
+	sessionStorageMutex.Lock()
+	defer sessionStorageMutex.Unlock()
+
+	if sessionStorage.Providers[provider] != nil {
+		delete(sessionStorage.Providers[provider].Sessions, sid)
+		if cfg != nil {
+			saveUserSessions(cfg)
+		}
+	}
+}
+
 // GetUserSession retrieves a user session
 func GetUserSession(provider string, sid string) *UserSession {
 	sessionStorageMutex.Lock()

+ 1 - 0
service/internal/httpservers/frontend.go

@@ -101,6 +101,7 @@ func StartFrontendMux(cfg *config.Config, ex *executor.Executor) {
 
 	oauth2handler := otoauth2.NewOAuth2Handler(cfg)
 	auth.AddAuthChainFunction(oauth2handler.CheckUserFromOAuth2Cookie)
+	auth.RegisterOAuth2SessionRevoker(oauth2handler.RevokeSession)
 
 	mux.HandleFunc("/oauth/login", oauth2handler.HandleOAuthLogin)
 	mux.HandleFunc("/oauth/callback", oauth2handler.HandleOAuthCallback)