|
|
@@ -219,11 +219,43 @@ func getOAuthCertBundle(providerConfig *config.OAuth2Provider) *x509.CertPool {
|
|
|
return caCertPool
|
|
|
}
|
|
|
|
|
|
+func (h *OAuth2Handler) exchangeOAuthCode(ctx context.Context, providerConfig *oauth2.Config, code string, clientSettings *HttpClientSettings) (*oauth2.Token, error) {
|
|
|
+ exchangeClient := &http.Client{
|
|
|
+ Transport: clientSettings.Transport,
|
|
|
+ Timeout: clientSettings.Timeout,
|
|
|
+ }
|
|
|
+
|
|
|
+ ctx = context.WithValue(ctx, oauth2.HTTPClient, exchangeClient)
|
|
|
+
|
|
|
+ return providerConfig.Exchange(ctx, code)
|
|
|
+}
|
|
|
+
|
|
|
+func (h *OAuth2Handler) createUserInfoClient(ctx context.Context, providerConfig *oauth2.Config, tok *oauth2.Token, clientSettings *HttpClientSettings) *http.Client {
|
|
|
+ return &http.Client{
|
|
|
+ Transport: &oauth2.Transport{
|
|
|
+ Source: providerConfig.TokenSource(ctx, tok),
|
|
|
+ Base: clientSettings.Transport,
|
|
|
+ },
|
|
|
+ Timeout: clientSettings.Timeout,
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func (h *OAuth2Handler) computeUsergroup(userinfo *UserInfo, providerConfig *config.OAuth2Provider) string {
|
|
|
+ usergroup := userinfo.Usergroup
|
|
|
+ if providerConfig != nil && providerConfig.AddToUsergroup != "" {
|
|
|
+ if usergroup != "" {
|
|
|
+ usergroup = usergroup + " " + providerConfig.AddToUsergroup
|
|
|
+ } else {
|
|
|
+ usergroup = providerConfig.AddToUsergroup
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return usergroup
|
|
|
+}
|
|
|
+
|
|
|
func (h *OAuth2Handler) HandleOAuthCallback(w http.ResponseWriter, r *http.Request) {
|
|
|
log.Infof("OAuth2 Callback received")
|
|
|
|
|
|
registeredState, state, ok := h.checkOAuthCallbackCookie(w, r)
|
|
|
-
|
|
|
if !ok {
|
|
|
return
|
|
|
}
|
|
|
@@ -236,48 +268,21 @@ func (h *OAuth2Handler) HandleOAuthCallback(w http.ResponseWriter, r *http.Reque
|
|
|
}).Debug("OAuth2 Token Code")
|
|
|
|
|
|
providerConfig := h.cfg.AuthOAuth2Providers[registeredState.providerName]
|
|
|
-
|
|
|
clientSettings := getOAuth2HttpClient(providerConfig)
|
|
|
|
|
|
- exchangeClient := &http.Client{
|
|
|
- Transport: clientSettings.Transport,
|
|
|
- Timeout: clientSettings.Timeout,
|
|
|
- }
|
|
|
-
|
|
|
ctx := context.Background()
|
|
|
- ctx = context.WithValue(ctx, oauth2.HTTPClient, exchangeClient)
|
|
|
-
|
|
|
- tok, err := registeredState.providerConfig.Exchange(ctx, code)
|
|
|
-
|
|
|
+ tok, err := h.exchangeOAuthCode(ctx, registeredState.providerConfig, code, clientSettings)
|
|
|
if err != nil {
|
|
|
log.Errorf("Failed to exchange code: %v", err)
|
|
|
http.Error(w, "Failed to exchange code", http.StatusBadRequest)
|
|
|
return
|
|
|
}
|
|
|
|
|
|
- userInfoClient := &http.Client{
|
|
|
- Transport: &oauth2.Transport{
|
|
|
- Source: registeredState.providerConfig.TokenSource(ctx, tok),
|
|
|
- Base: clientSettings.Transport,
|
|
|
- },
|
|
|
- Timeout: clientSettings.Timeout,
|
|
|
- }
|
|
|
-
|
|
|
+ userInfoClient := h.createUserInfoClient(ctx, registeredState.providerConfig, tok, clientSettings)
|
|
|
userinfo := getUserInfo(h.cfg, userInfoClient, providerConfig)
|
|
|
|
|
|
h.registeredStates[state].Username = userinfo.Username
|
|
|
-
|
|
|
- usergroup := userinfo.Usergroup
|
|
|
- if providerConfig != nil && providerConfig.AddToUsergroup != "" {
|
|
|
- // Append configured usergroup name if addToUsergroup is set
|
|
|
- if usergroup != "" {
|
|
|
- usergroup = usergroup + " " + providerConfig.AddToUsergroup
|
|
|
- } else {
|
|
|
- usergroup = providerConfig.AddToUsergroup
|
|
|
- }
|
|
|
- }
|
|
|
-
|
|
|
- h.registeredStates[state].Usergroup = usergroup
|
|
|
+ h.registeredStates[state].Usergroup = h.computeUsergroup(userinfo, providerConfig)
|
|
|
|
|
|
http.Redirect(w, r, "/", http.StatusFound)
|
|
|
}
|