jamesread 6 місяців тому
батько
коміт
e9d90060fd
1 змінених файлів з 36 додано та 31 видалено
  1. 36 31
      service/internal/auth/otoauth2/restapi_auth_oauth2.go

+ 36 - 31
service/internal/auth/otoauth2/restapi_auth_oauth2.go

@@ -219,11 +219,43 @@ func getOAuthCertBundle(providerConfig *config.OAuth2Provider) *x509.CertPool {
 	return caCertPool
 }
 
+func (h *OAuth2Handler) exchangeOAuthCode(ctx context.Context, providerConfig *oauth2.Config, code string, clientSettings *HttpClientSettings) (*oauth2.Token, error) {
+	exchangeClient := &http.Client{
+		Transport: clientSettings.Transport,
+		Timeout:   clientSettings.Timeout,
+	}
+
+	ctx = context.WithValue(ctx, oauth2.HTTPClient, exchangeClient)
+
+	return providerConfig.Exchange(ctx, code)
+}
+
+func (h *OAuth2Handler) createUserInfoClient(ctx context.Context, providerConfig *oauth2.Config, tok *oauth2.Token, clientSettings *HttpClientSettings) *http.Client {
+	return &http.Client{
+		Transport: &oauth2.Transport{
+			Source: providerConfig.TokenSource(ctx, tok),
+			Base:   clientSettings.Transport,
+		},
+		Timeout: clientSettings.Timeout,
+	}
+}
+
+func (h *OAuth2Handler) computeUsergroup(userinfo *UserInfo, providerConfig *config.OAuth2Provider) string {
+	usergroup := userinfo.Usergroup
+	if providerConfig != nil && providerConfig.AddToUsergroup != "" {
+		if usergroup != "" {
+			usergroup = usergroup + " " + providerConfig.AddToUsergroup
+		} else {
+			usergroup = providerConfig.AddToUsergroup
+		}
+	}
+	return usergroup
+}
+
 func (h *OAuth2Handler) HandleOAuthCallback(w http.ResponseWriter, r *http.Request) {
 	log.Infof("OAuth2 Callback received")
 
 	registeredState, state, ok := h.checkOAuthCallbackCookie(w, r)
-
 	if !ok {
 		return
 	}
@@ -236,48 +268,21 @@ func (h *OAuth2Handler) HandleOAuthCallback(w http.ResponseWriter, r *http.Reque
 	}).Debug("OAuth2 Token Code")
 
 	providerConfig := h.cfg.AuthOAuth2Providers[registeredState.providerName]
-
 	clientSettings := getOAuth2HttpClient(providerConfig)
 
-	exchangeClient := &http.Client{
-		Transport: clientSettings.Transport,
-		Timeout:   clientSettings.Timeout,
-	}
-
 	ctx := context.Background()
-	ctx = context.WithValue(ctx, oauth2.HTTPClient, exchangeClient)
-
-	tok, err := registeredState.providerConfig.Exchange(ctx, code)
-
+	tok, err := h.exchangeOAuthCode(ctx, registeredState.providerConfig, code, clientSettings)
 	if err != nil {
 		log.Errorf("Failed to exchange code: %v", err)
 		http.Error(w, "Failed to exchange code", http.StatusBadRequest)
 		return
 	}
 
-	userInfoClient := &http.Client{
-		Transport: &oauth2.Transport{
-			Source: registeredState.providerConfig.TokenSource(ctx, tok),
-			Base:   clientSettings.Transport,
-		},
-		Timeout: clientSettings.Timeout,
-	}
-
+	userInfoClient := h.createUserInfoClient(ctx, registeredState.providerConfig, tok, clientSettings)
 	userinfo := getUserInfo(h.cfg, userInfoClient, providerConfig)
 
 	h.registeredStates[state].Username = userinfo.Username
-	
-	usergroup := userinfo.Usergroup
-	if providerConfig != nil && providerConfig.AddToUsergroup != "" {
-		// Append configured usergroup name if addToUsergroup is set
-		if usergroup != "" {
-			usergroup = usergroup + " " + providerConfig.AddToUsergroup
-		} else {
-			usergroup = providerConfig.AddToUsergroup
-		}
-	}
-	
-	h.registeredStates[state].Usergroup = usergroup
+	h.registeredStates[state].Usergroup = h.computeUsergroup(userinfo, providerConfig)
 
 	http.Redirect(w, r, "/", http.StatusFound)
 }