Просмотр исходного кода

Merge branch 'next' into advisory-fix-1

James Read 4 месяцев назад
Родитель
Сommit
5e0c052e68

+ 13 - 0
SECURITY.md

@@ -40,3 +40,16 @@ The following notes might be helpful when reporting a vulnerability:
 ## Disclosure of how vulnerabilities were found
 
 It is incredibly useful to not just patch security vulnerabilities, but also to understand how they were found. If you are able to share this information, it can help us and the community to better understand potential attack vectors and improve the overall security of the project.
+
+## Process
+
+Once a vulnerability is reported, the process is;
+
+* Accept or reject the report, and communicate with the reporter about next steps.
+* If accepted, patch using a temporary branch, and code review will be requested from the original reporter if they are interested.
+* The severity of the vulnerability will be assessed using CVSS, and the patch will be prioritized accordingly.
+* Once the patch is ready, it will be queued for a release onto the `next` branch (3k) or `release/2k` branch (2k)
+* The reporter will be credited in the advistory and the release notes, but not the commit message.
+* The commit message will contain a reference to the CVSS score (eg: MED) and the advisory ID. 
+
+

+ 19 - 8
service/internal/api/api.go

@@ -413,6 +413,8 @@ func (api *oliveTinAPI) ExecutionStatus(ctx ctx.Context, req *connect.Request[ap
 func (api *oliveTinAPI) Logout(ctx ctx.Context, req *connect.Request[apiv1.LogoutRequest]) (*connect.Response[apiv1.LogoutResponse], error) {
 	user := auth.UserFromApiCall(ctx, req, api.cfg)
 
+	auth.RevokeSessionForProvider(api.cfg, user.Provider, user.SID)
+
 	log.WithFields(log.Fields{
 		"username": user.Username,
 		"provider": user.Provider,
@@ -1304,7 +1306,9 @@ func serializeEntityFields(data any) map[string]string {
 }
 
 func (api *oliveTinAPI) RestartAction(ctx ctx.Context, req *connect.Request[apiv1.RestartActionRequest]) (*connect.Response[apiv1.StartActionResponse], error) {
-	var execReqLogEntry *executor.InternalLogEntry
+	ret := &apiv1.StartActionResponse{
+		ExecutionTrackingId: req.Msg.ExecutionTrackingId,
+	}
 
 	execReqLogEntry, found := api.executor.GetLog(req.Msg.ExecutionTrackingId)
 
@@ -1322,14 +1326,21 @@ func (api *oliveTinAPI) RestartAction(ctx ctx.Context, req *connect.Request[apiv
 		return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("action not found for tracking ID %s", req.Msg.ExecutionTrackingId))
 	}
 
-	log.Warnf("Restarting execution request by tracking ID: %v", req.Msg.ExecutionTrackingId)
+	authenticatedUser := auth.UserFromApiCall(ctx, req, api.cfg)
 
-	return api.StartAction(ctx, &connect.Request[apiv1.StartActionRequest]{
-		Msg: &apiv1.StartActionRequest{
-			BindingId:        execReqLogEntry.GetBindingId(),
-			UniqueTrackingId: req.Msg.ExecutionTrackingId,
-		},
-	})
+	// TrackingID is deliberately not passed to the executor, so that it generates a new one for the restarted execution.
+	// This is because the old execution (identified by the old TrackingID) is already used.
+	execReq := executor.ExecutionRequest{
+		Binding:           execReqLogEntry.Binding,
+		Arguments:         make(map[string]string),
+		AuthenticatedUser: authenticatedUser,
+		Cfg:               api.cfg,
+	}
+
+	api.executor.ExecRequest(&execReq)
+
+	ret.ExecutionTrackingId = execReq.TrackingID
+	return connect.NewResponse(ret), nil
 }
 
 func newServer(ex *executor.Executor) *oliveTinAPI {

+ 19 - 5
service/internal/auth/otjwt/jwt.go

@@ -33,6 +33,13 @@ func parseJwtToken(cfg *config.Config, jwtString string) (*jwt.Token, error) {
 	return parseJwtTokenWithHMAC(cfg, jwtString)
 }
 
+func parserOptionsWithAudience(cfg *config.Config) []jwt.ParserOption {
+	if cfg.AuthJwtAud == "" {
+		return nil
+	}
+	return []jwt.ParserOption{jwt.WithAudience(cfg.AuthJwtAud)}
+}
+
 func getClaimsFromJwtToken(cfg *config.Config, jwtString string) (jwt.MapClaims, error) {
 	token, err := parseJwtToken(cfg, jwtString)
 
@@ -56,7 +63,8 @@ func parseJwtTokenWithRemoteKey(cfg *config.Config, jwtToken string) (*jwt.Token
 		return nil, err
 	}
 
-	return jwt.Parse(jwtToken, jwksVerifier.Keyfunc, jwt.WithAudience(cfg.AuthJwtAud))
+	opts := parserOptionsWithAudience(cfg)
+	return jwt.Parse(jwtToken, jwksVerifier.Keyfunc, opts...)
 }
 
 var (
@@ -148,24 +156,30 @@ func parseJwtTokenWithLocalKey(cfg *config.Config, jwtString string) (*jwt.Token
 		return nil, err
 	}
 
-	return jwt.Parse(jwtString, func(token *jwt.Token) (interface{}, error) {
+	keyFunc := func(token *jwt.Token) (interface{}, error) {
 		if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok {
 			return nil, fmt.Errorf("parseJwt expected token algorithm RSA but got: %v", token.Header["alg"])
 		}
 
 		return pubKey, nil
-	})
+	}
+
+	opts := parserOptionsWithAudience(cfg)
+	return jwt.Parse(jwtString, keyFunc, opts...)
 }
 
 // Hash-based Message Authentication Code
 func parseJwtTokenWithHMAC(cfg *config.Config, jwtString string) (*jwt.Token, error) {
-	return jwt.Parse(jwtString, func(token *jwt.Token) (interface{}, error) {
+	keyFunc := func(token *jwt.Token) (interface{}, error) {
 		if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
 			return nil, fmt.Errorf("parseJwt expected token algorithm HMAC but got: %v", token.Header["alg"])
 		}
 
 		return []byte(cfg.AuthJwtHmacSecret), nil
-	})
+	}
+
+	opts := parserOptionsWithAudience(cfg)
+	return jwt.Parse(jwtString, keyFunc, opts...)
 }
 
 func lookupClaimValueOrDefault(claims jwt.MapClaims, key string, def string) string {

+ 21 - 1
service/internal/auth/otjwt/jwt_test.go

@@ -66,12 +66,19 @@ func newMux() *http.ServeMux {
 }
 
 func createJWTTokenWithExpiration(t *testing.T, privateKey *rsa.PrivateKey, expire int64) string {
+	return createJWTTokenWithExpirationAndAudience(t, privateKey, expire, "")
+}
+
+func createJWTTokenWithExpirationAndAudience(t *testing.T, privateKey *rsa.PrivateKey, expire int64, audience string) string {
 	token := jwt.New(jwt.SigningMethodRS256)
 	claims := token.Claims.(jwt.MapClaims)
 	claims["nbf"] = time.Now().Unix() - 1000
 	claims["exp"] = time.Now().Unix() + expire
 	claims["sub"] = "test"
 	claims["olivetinGroup"] = "test"
+	if audience != "" {
+		claims["aud"] = audience
+	}
 
 	tokenStr, err := token.SignedString(privateKey)
 	if err != nil {
@@ -108,6 +115,10 @@ func verifyJWTResponse(t *testing.T, res *http.Response, expectCode int) {
 }
 
 func testJwkValidation(t *testing.T, expire int64, expectCode int) {
+	testJwkValidationWithAudience(t, expire, expectCode, "", "")
+}
+
+func testJwkValidationWithAudience(t *testing.T, expire int64, expectCode int, configAudience, tokenAudience string) {
 	privateKey, publicKeyPath := createKeys(t)
 	defer os.Remove(publicKeyPath)
 
@@ -116,8 +127,9 @@ func testJwkValidation(t *testing.T, expire int64, expectCode int) {
 	cfg.AuthJwtClaimUsername = "sub"
 	cfg.AuthJwtClaimUserGroup = "olivetinGroup"
 	cfg.AuthJwtHeader = "Authorization"
+	cfg.AuthJwtAud = configAudience
 
-	tokenStr := createJWTTokenWithExpiration(t, privateKey, expire)
+	tokenStr := createJWTTokenWithExpirationAndAudience(t, privateKey, expire, tokenAudience)
 	handler := setupJWTTestHandler(t, cfg)
 
 	srv := httptest.NewServer(handler)
@@ -135,6 +147,14 @@ func TestJWTSignatureVerificationFails(t *testing.T) {
 	testJwkValidation(t, -500, 403)
 }
 
+func TestJWTAudienceValidationRejectsWrongAudience(t *testing.T) {
+	testJwkValidationWithAudience(t, 1000, 403, "expected-audience", "wrong-audience")
+}
+
+func TestJWTAudienceValidationAcceptsCorrectAudience(t *testing.T) {
+	testJwkValidationWithAudience(t, 1000, 200, "expected-audience", "expected-audience")
+}
+
 func createJWTTokenWithGroups(t *testing.T, privateKey *rsa.PrivateKey, groups interface{}) string {
 	token := jwt.New(jwt.SigningMethodRS256)
 	claims := token.Claims.(jwt.MapClaims)

+ 6 - 0
service/internal/auth/otoauth2/restapi_auth_oauth2.go

@@ -391,6 +391,12 @@ func (h *OAuth2Handler) lookupOAuth2UserByState(state string) (*authTypes.Authen
 	return user, true
 }
 
+func (h *OAuth2Handler) RevokeSession(sid string) {
+	h.mu.Lock()
+	defer h.mu.Unlock()
+	delete(h.registeredStates, sid)
+}
+
 func (h *OAuth2Handler) CheckUserFromOAuth2Cookie(context *authTypes.AuthCheckingContext) *authTypes.AuthenticatedUser {
 	cookie, err := context.Request.Cookie("olivetin-sid-oauth")
 	if err != nil || cookie.Value == "" {

+ 35 - 2
service/internal/auth/sessions.go

@@ -25,8 +25,9 @@ type SessionStorage struct {
 }
 
 var (
-	sessionStorage      *SessionStorage
-	sessionStorageMutex sync.RWMutex
+	sessionStorage       *SessionStorage
+	sessionStorageMutex  sync.RWMutex
+	oauth2SessionRevoker func(sid string)
 )
 
 func init() {
@@ -58,6 +59,38 @@ func RegisterUserSession(cfg *config.Config, provider string, sid string, userna
 	saveUserSessions(cfg)
 }
 
+// RegisterOAuth2SessionRevoker registers a callback to revoke OAuth2 sessions on logout.
+// OAuth2 uses its own session storage; the API calls this when provider is oauth2.
+func RegisterOAuth2SessionRevoker(fn func(sid string)) {
+	oauth2SessionRevoker = fn
+}
+
+// RevokeSessionForProvider invalidates the session for the given provider and SID (e.g. on logout).
+// Local auth uses shared SessionStorage; OAuth2 uses a separate storage and revoker.
+func RevokeSessionForProvider(cfg *config.Config, provider string, sid string) {
+	if sid == "" {
+		return
+	}
+	if provider == "oauth2" && oauth2SessionRevoker != nil {
+		oauth2SessionRevoker(sid)
+		return
+	}
+	RevokeUserSession(cfg, provider, sid)
+}
+
+// RevokeUserSession removes a session from storage so it can no longer be used (e.g. on logout).
+func RevokeUserSession(cfg *config.Config, provider string, sid string) {
+	sessionStorageMutex.Lock()
+	defer sessionStorageMutex.Unlock()
+
+	if sessionStorage.Providers[provider] != nil {
+		delete(sessionStorage.Providers[provider].Sessions, sid)
+		if cfg != nil {
+			saveUserSessions(cfg)
+		}
+	}
+}
+
 // GetUserSession retrieves a user session
 func GetUserSession(provider string, sid string) *UserSession {
 	sessionStorageMutex.Lock()

+ 11 - 4
service/internal/executor/executor.go

@@ -1015,8 +1015,15 @@ func stepTrigger(req *ExecutionRequest) bool {
 }
 
 func triggerLoop(req *ExecutionRequest) {
-	for _, triggerReq := range req.Binding.Action.Triggers {
-		binding := req.executor.FindBindingByID(triggerReq)
+	for _, triggerTitle := range req.Binding.Action.Triggers {
+		binding := req.executor.findBindingByActionTitle(triggerTitle, "")
+		if binding == nil {
+			log.WithFields(log.Fields{
+				"triggerTitle": triggerTitle,
+				"fromAction":   req.logEntry.ActionTitle,
+			}).Warnf("Trigger references unknown action title; skipping")
+			continue
+		}
 		trigger := &ExecutionRequest{
 			Binding:           binding,
 			TrackingID:        uuid.NewString(),
@@ -1059,7 +1066,7 @@ func saveLogResults(req *ExecutionRequest, filename string) {
 		}
 
 		filepath := path.Join(dir, filename+".yaml")
-		err = os.WriteFile(filepath, data, 0644)
+		err = os.WriteFile(filepath, data, 0600)
 
 		if err != nil {
 			log.Warnf("%v", err)
@@ -1073,7 +1080,7 @@ func saveLogOutput(req *ExecutionRequest, filename string) {
 	if dir != "" {
 		data := req.logEntry.Output
 		filepath := path.Join(dir, filename+".log")
-		err := os.WriteFile(filepath, []byte(data), 0644)
+		err := os.WriteFile(filepath, []byte(data), 0600)
 
 		if err != nil {
 			log.Warnf("%v", err)

+ 90 - 0
service/internal/executor/executor_test.go

@@ -2,6 +2,7 @@ package executor
 
 import (
 	"testing"
+	"time"
 
 	"github.com/stretchr/testify/assert"
 
@@ -395,3 +396,92 @@ func TestFilterToDefinedArgumentsPreservesSystemArgs(t *testing.T) {
 	assert.Equal(t, "track-123", req.Arguments["ot_executionTrackingId"])
 	assert.Equal(t, "webhook", req.Arguments["ot_username"])
 }
+
+func TestTriggerExecutesTriggeredAction(t *testing.T) {
+	cfg := config.DefaultConfig()
+	e := DefaultExecutor(cfg)
+	helloAction := &config.Action{
+		Title: "Hello world",
+		Shell: "echo 'Hello World!'",
+	}
+	triggerAction := &config.Action{
+		Title:    "Simple action that triggers another action",
+		Shell:    "echo 'Hi'",
+		Triggers: []string{"Hello world"},
+	}
+	cfg.Actions = append(cfg.Actions, helloAction, triggerAction)
+	cfg.Sanitize()
+	e.RebuildActionMap()
+
+	finishedTitles := make(chan string, 4)
+	collector := &executionFinishedCollector{ch: finishedTitles}
+	e.AddListener(collector)
+
+	req := &ExecutionRequest{
+		AuthenticatedUser: auth.UserFromSystem(cfg, "testuser"),
+		Cfg:               cfg,
+		Binding:           e.FindBindingWithNoEntity(triggerAction),
+	}
+	wg, _ := e.ExecRequest(req)
+	wg.Wait()
+
+	var got []string
+	for i := 0; i < 2; i++ {
+		select {
+		case title := <-finishedTitles:
+			got = append(got, title)
+		case <-time.After(2 * time.Second):
+			t.Fatalf("timed out waiting for execution %d; got %v", i+1, got)
+		}
+	}
+	assert.Contains(t, got, "Hello world", "triggered action must run")
+	assert.Contains(t, got, "Simple action that triggers another action", "triggering action must run")
+}
+
+func TestTriggerUnknownActionTitleSkipsWithoutPanic(t *testing.T) {
+	cfg := config.DefaultConfig()
+	e := DefaultExecutor(cfg)
+	triggerAction := &config.Action{
+		Title:    "Action with bad trigger",
+		Shell:    "echo 'ok'",
+		Triggers: []string{"Nonexistent action"},
+	}
+	cfg.Actions = append(cfg.Actions, triggerAction)
+	cfg.Sanitize()
+	e.RebuildActionMap()
+
+	finishedTitles := make(chan string, 4)
+	collector := &executionFinishedCollector{ch: finishedTitles}
+	e.AddListener(collector)
+
+	req := &ExecutionRequest{
+		AuthenticatedUser: auth.UserFromSystem(cfg, "testuser"),
+		Cfg:               cfg,
+		Binding:           e.FindBindingWithNoEntity(triggerAction),
+	}
+	wg, _ := e.ExecRequest(req)
+	wg.Wait()
+
+	var got []string
+	select {
+	case title := <-finishedTitles:
+		got = append(got, title)
+	case <-time.After(500 * time.Millisecond):
+	}
+	assert.Len(t, got, 1, "only the triggering action runs; unknown trigger is skipped")
+	assert.Equal(t, "Action with bad trigger", got[0])
+}
+
+type executionFinishedCollector struct {
+	ch chan string
+}
+
+func (c *executionFinishedCollector) OnExecutionStarted(_ *InternalLogEntry) {}
+
+func (c *executionFinishedCollector) OnExecutionFinished(entry *InternalLogEntry) {
+	c.ch <- entry.ActionTitle
+}
+
+func (c *executionFinishedCollector) OnOutputChunk(_ []byte, _ string) {}
+
+func (c *executionFinishedCollector) OnActionMapRebuilt() {}

+ 1 - 0
service/internal/httpservers/frontend.go

@@ -101,6 +101,7 @@ func StartFrontendMux(cfg *config.Config, ex *executor.Executor) {
 
 	oauth2handler := otoauth2.NewOAuth2Handler(cfg)
 	auth.AddAuthChainFunction(oauth2handler.CheckUserFromOAuth2Cookie)
+	auth.RegisterOAuth2SessionRevoker(oauth2handler.RevokeSession)
 
 	mux.HandleFunc("/oauth/login", oauth2handler.HandleOAuthLogin)
 	mux.HandleFunc("/oauth/callback", oauth2handler.HandleOAuthCallback)