Browse Source

Add OAuth2 PKCE support

Frédéric Guillot 2 years ago
parent
commit
ff5d391701

+ 5 - 0
internal/http/request/context.go

@@ -20,6 +20,7 @@ const (
 	SessionIDContextKey
 	CSRFContextKey
 	OAuth2StateContextKey
+	OAuth2CodeVerifierContextKey
 	FlashMessageContextKey
 	FlashErrorMessageContextKey
 	PocketRequestTokenContextKey
@@ -94,6 +95,10 @@ func OAuth2State(r *http.Request) string {
 	return getContextStringValue(r, OAuth2StateContextKey)
 }
 
+func OAuth2CodeVerifier(r *http.Request) string {
+	return getContextStringValue(r, OAuth2CodeVerifierContextKey)
+}
+
 // FlashMessage returns the message message if any.
 func FlashMessage(r *http.Request) string {
 	return getContextStringValue(r, FlashMessageContextKey)

+ 3 - 2
internal/model/app_session.go

@@ -14,6 +14,7 @@ import (
 type SessionData struct {
 	CSRF               string `json:"csrf"`
 	OAuth2State        string `json:"oauth2_state"`
+	OAuth2CodeVerifier string `json:"oauth2_code_verifier"`
 	FlashMessage       string `json:"flash_message"`
 	FlashErrorMessage  string `json:"flash_error_message"`
 	Language           string `json:"language"`
@@ -22,8 +23,8 @@ type SessionData struct {
 }
 
 func (s SessionData) String() string {
-	return fmt.Sprintf(`CSRF=%q, OAuth2State=%q, FlashMsg=%q, FlashErrMsg=%q, Lang=%q, Theme=%q, PocketTkn=%q`,
-		s.CSRF, s.OAuth2State, s.FlashMessage, s.FlashErrorMessage, s.Language, s.Theme, s.PocketRequestToken)
+	return fmt.Sprintf(`CSRF=%q, OAuth2State=%q, OAuth2CodeVerifier=%q, FlashMsg=%q, FlashErrMsg=%q, Lang=%q, Theme=%q, PocketTkn=%q`,
+		s.CSRF, s.OAuth2State, s.OAuth2CodeVerifier, s.FlashMessage, s.FlashErrorMessage, s.Language, s.Theme, s.PocketRequestToken)
 }
 
 // Value converts the session data to JSON.

+ 54 - 0
internal/oauth2/authorization.go

@@ -0,0 +1,54 @@
+// SPDX-FileCopyrightText: Copyright The Miniflux Authors. All rights reserved.
+// SPDX-License-Identifier: Apache-2.0
+
+package oauth2 // import "miniflux.app/v2/internal/oauth2"
+
+import (
+	"crypto/sha256"
+	"encoding/base64"
+	"io"
+
+	"golang.org/x/oauth2"
+
+	"miniflux.app/v2/internal/crypto"
+)
+
+type Authorization struct {
+	url          string
+	state        string
+	codeVerifier string
+}
+
+func (u *Authorization) RedirectURL() string {
+	return u.url
+}
+
+func (u *Authorization) State() string {
+	return u.state
+}
+
+func (u *Authorization) CodeVerifier() string {
+	return u.codeVerifier
+}
+
+func GenerateAuthorization(config *oauth2.Config) *Authorization {
+	codeVerifier := crypto.GenerateRandomStringHex(32)
+
+	sha2 := sha256.New()
+	io.WriteString(sha2, codeVerifier)
+	codeChallenge := base64.RawURLEncoding.EncodeToString(sha2.Sum(nil))
+
+	state := crypto.GenerateRandomStringHex(24)
+
+	authUrl := config.AuthCodeURL(
+		state,
+		oauth2.SetAuthURLParam("code_challenge_method", "S256"),
+		oauth2.SetAuthURLParam("code_challenge", codeChallenge),
+	)
+
+	return &Authorization{
+		url:          authUrl,
+		state:        state,
+		codeVerifier: codeVerifier,
+	}
+}

+ 20 - 24
internal/oauth2/google.go

@@ -24,17 +24,30 @@ type googleProvider struct {
 	redirectURL  string
 }
 
-func (g *googleProvider) GetUserExtraKey() string {
-	return "google_id"
+func NewGoogleProvider(clientID, clientSecret, redirectURL string) *googleProvider {
+	return &googleProvider{clientID: clientID, clientSecret: clientSecret, redirectURL: redirectURL}
+}
+
+func (g *googleProvider) GetConfig() *oauth2.Config {
+	return &oauth2.Config{
+		RedirectURL:  g.redirectURL,
+		ClientID:     g.clientID,
+		ClientSecret: g.clientSecret,
+		Scopes:       []string{"email"},
+		Endpoint: oauth2.Endpoint{
+			AuthURL:  "https://accounts.google.com/o/oauth2/auth",
+			TokenURL: "https://accounts.google.com/o/oauth2/token",
+		},
+	}
 }
 
-func (g *googleProvider) GetRedirectURL(state string) string {
-	return g.config().AuthCodeURL(state)
+func (g *googleProvider) GetUserExtraKey() string {
+	return "google_id"
 }
 
-func (g *googleProvider) GetProfile(ctx context.Context, code string) (*Profile, error) {
-	conf := g.config()
-	token, err := conf.Exchange(ctx, code)
+func (g *googleProvider) GetProfile(ctx context.Context, code, codeVerifier string) (*Profile, error) {
+	conf := g.GetConfig()
+	token, err := conf.Exchange(ctx, code, oauth2.SetAuthURLParam("code_verifier", codeVerifier))
 	if err != nil {
 		return nil, err
 	}
@@ -67,20 +80,3 @@ func (g *googleProvider) PopulateUserWithProfileID(user *model.User, profile *Pr
 func (g *googleProvider) UnsetUserProfileID(user *model.User) {
 	user.GoogleID = ""
 }
-
-func (g *googleProvider) config() *oauth2.Config {
-	return &oauth2.Config{
-		RedirectURL:  g.redirectURL,
-		ClientID:     g.clientID,
-		ClientSecret: g.clientSecret,
-		Scopes:       []string{"email"},
-		Endpoint: oauth2.Endpoint{
-			AuthURL:  "https://accounts.google.com/o/oauth2/auth",
-			TokenURL: "https://accounts.google.com/o/oauth2/token",
-		},
-	}
-}
-
-func newGoogleProvider(clientID, clientSecret, redirectURL string) *googleProvider {
-	return &googleProvider{clientID: clientID, clientSecret: clientSecret, redirectURL: redirectURL}
-}

+ 2 - 6
internal/oauth2/manager.go

@@ -10,12 +10,10 @@ import (
 	"miniflux.app/v2/internal/logger"
 )
 
-// Manager handles OAuth2 providers.
 type Manager struct {
 	providers map[string]Provider
 }
 
-// FindProvider returns the given provider.
 func (m *Manager) FindProvider(name string) (Provider, error) {
 	if provider, found := m.providers[name]; found {
 		return provider, nil
@@ -24,18 +22,16 @@ func (m *Manager) FindProvider(name string) (Provider, error) {
 	return nil, errors.New("oauth2 provider not found")
 }
 
-// AddProvider add a new OAuth2 provider.
 func (m *Manager) AddProvider(name string, provider Provider) {
 	m.providers[name] = provider
 }
 
-// NewManager returns a new Manager.
 func NewManager(ctx context.Context, clientID, clientSecret, redirectURL, oidcDiscoveryEndpoint string) *Manager {
 	m := &Manager{providers: make(map[string]Provider)}
-	m.AddProvider("google", newGoogleProvider(clientID, clientSecret, redirectURL))
+	m.AddProvider("google", NewGoogleProvider(clientID, clientSecret, redirectURL))
 
 	if oidcDiscoveryEndpoint != "" {
-		if genericOidcProvider, err := newOidcProvider(ctx, clientID, clientSecret, redirectURL, oidcDiscoveryEndpoint); err != nil {
+		if genericOidcProvider, err := NewOidcProvider(ctx, clientID, clientSecret, redirectURL, oidcDiscoveryEndpoint); err != nil {
 			logger.Error("[OAuth2] failed to initialize OIDC provider: %v", err)
 		} else {
 			m.AddProvider("oidc", genericOidcProvider)

+ 20 - 24
internal/oauth2/oidc.go

@@ -19,17 +19,32 @@ type oidcProvider struct {
 	provider     *oidc.Provider
 }
 
+func NewOidcProvider(ctx context.Context, clientID, clientSecret, redirectURL, discoveryEndpoint string) (*oidcProvider, error) {
+	provider, err := oidc.NewProvider(ctx, discoveryEndpoint)
+	if err != nil {
+		return nil, err
+	}
+
+	return &oidcProvider{clientID: clientID, clientSecret: clientSecret, redirectURL: redirectURL, provider: provider}, nil
+}
+
 func (o *oidcProvider) GetUserExtraKey() string {
 	return "openid_connect_id"
 }
 
-func (o *oidcProvider) GetRedirectURL(state string) string {
-	return o.config().AuthCodeURL(state)
+func (o *oidcProvider) GetConfig() *oauth2.Config {
+	return &oauth2.Config{
+		RedirectURL:  o.redirectURL,
+		ClientID:     o.clientID,
+		ClientSecret: o.clientSecret,
+		Scopes:       []string{"openid", "email"},
+		Endpoint:     o.provider.Endpoint(),
+	}
 }
 
-func (o *oidcProvider) GetProfile(ctx context.Context, code string) (*Profile, error) {
-	conf := o.config()
-	token, err := conf.Exchange(ctx, code)
+func (o *oidcProvider) GetProfile(ctx context.Context, code, codeVerifier string) (*Profile, error) {
+	conf := o.GetConfig()
+	token, err := conf.Exchange(ctx, code, oauth2.SetAuthURLParam("code_verifier", codeVerifier))
 	if err != nil {
 		return nil, err
 	}
@@ -54,22 +69,3 @@ func (o *oidcProvider) PopulateUserWithProfileID(user *model.User, profile *Prof
 func (o *oidcProvider) UnsetUserProfileID(user *model.User) {
 	user.OpenIDConnectID = ""
 }
-
-func (o *oidcProvider) config() *oauth2.Config {
-	return &oauth2.Config{
-		RedirectURL:  o.redirectURL,
-		ClientID:     o.clientID,
-		ClientSecret: o.clientSecret,
-		Scopes:       []string{"openid", "email"},
-		Endpoint:     o.provider.Endpoint(),
-	}
-}
-
-func newOidcProvider(ctx context.Context, clientID, clientSecret, redirectURL, discoveryEndpoint string) (*oidcProvider, error) {
-	provider, err := oidc.NewProvider(ctx, discoveryEndpoint)
-	if err != nil {
-		return nil, err
-	}
-
-	return &oidcProvider{clientID: clientID, clientSecret: clientSecret, redirectURL: redirectURL, provider: provider}, nil
-}

+ 4 - 2
internal/oauth2/provider.go

@@ -6,14 +6,16 @@ package oauth2 // import "miniflux.app/v2/internal/oauth2"
 import (
 	"context"
 
+	"golang.org/x/oauth2"
+
 	"miniflux.app/v2/internal/model"
 )
 
 // Provider is an interface for OAuth2 providers.
 type Provider interface {
+	GetConfig() *oauth2.Config
 	GetUserExtraKey() string
-	GetRedirectURL(state string) string
-	GetProfile(ctx context.Context, code string) (*Profile, error)
+	GetProfile(ctx context.Context, code, codeVerifier string) (*Profile, error)
 	PopulateUserCreationWithProfileID(user *model.UserCreationRequest, profile *Profile)
 	PopulateUserWithProfileID(user *model.User, profile *Profile)
 	UnsetUserProfileID(user *model.User)

+ 1 - 1
internal/storage/session.go

@@ -53,7 +53,7 @@ func (s *Storage) createAppSession(session *model.Session) (*model.Session, erro
 }
 
 // UpdateAppSessionField updates only one session field.
-func (s *Storage) UpdateAppSessionField(sessionID, field string, value interface{}) error {
+func (s *Storage) UpdateAppSessionField(sessionID, field string, value any) error {
 	query := `
 		UPDATE
 			sessions

+ 2 - 1
internal/ui/middleware.go

@@ -94,7 +94,7 @@ func (m *middleware) handleAppSession(next http.Handler) http.Handler {
 					return
 				}
 
-				html.BadRequest(w, r, errors.New("Invalid or missing CSRF"))
+				html.BadRequest(w, r, errors.New("invalid or missing CSRF"))
 				return
 			}
 		}
@@ -103,6 +103,7 @@ func (m *middleware) handleAppSession(next http.Handler) http.Handler {
 		ctx = context.WithValue(ctx, request.SessionIDContextKey, session.ID)
 		ctx = context.WithValue(ctx, request.CSRFContextKey, session.Data.CSRF)
 		ctx = context.WithValue(ctx, request.OAuth2StateContextKey, session.Data.OAuth2State)
+		ctx = context.WithValue(ctx, request.OAuth2CodeVerifierContextKey, session.Data.OAuth2CodeVerifier)
 		ctx = context.WithValue(ctx, request.FlashMessageContextKey, session.Data.FlashMessage)
 		ctx = context.WithValue(ctx, request.FlashErrorMessageContextKey, session.Data.FlashErrorMessage)
 		ctx = context.WithValue(ctx, request.UserLanguageContextKey, session.Data.Language)

+ 3 - 2
internal/ui/oauth2_callback.go

@@ -4,6 +4,7 @@
 package ui // import "miniflux.app/v2/internal/ui"
 
 import (
+	"crypto/subtle"
 	"errors"
 	"net/http"
 
@@ -38,7 +39,7 @@ func (h *handler) oauth2Callback(w http.ResponseWriter, r *http.Request) {
 	}
 
 	state := request.QueryStringParam(r, "state", "")
-	if state == "" || state != request.OAuth2State(r) {
+	if subtle.ConstantTimeCompare([]byte(state), []byte(request.OAuth2State(r))) == 0 {
 		logger.Error(`[OAuth2] Invalid state value: got "%s" instead of "%s"`, state, request.OAuth2State(r))
 		html.Redirect(w, r, route.Path(h.router, "login"))
 		return
@@ -51,7 +52,7 @@ func (h *handler) oauth2Callback(w http.ResponseWriter, r *http.Request) {
 		return
 	}
 
-	profile, err := authProvider.GetProfile(r.Context(), code)
+	profile, err := authProvider.GetProfile(r.Context(), code, request.OAuth2CodeVerifier(r))
 	if err != nil {
 		logger.Error("[OAuth2] %v", err)
 		html.Redirect(w, r, route.Path(h.router, "login"))

+ 7 - 1
internal/ui/oauth2_redirect.go

@@ -10,6 +10,7 @@ import (
 	"miniflux.app/v2/internal/http/response/html"
 	"miniflux.app/v2/internal/http/route"
 	"miniflux.app/v2/internal/logger"
+	"miniflux.app/v2/internal/oauth2"
 	"miniflux.app/v2/internal/ui/session"
 )
 
@@ -30,5 +31,10 @@ func (h *handler) oauth2Redirect(w http.ResponseWriter, r *http.Request) {
 		return
 	}
 
-	html.Redirect(w, r, authProvider.GetRedirectURL(sess.NewOAuth2State()))
+	auth := oauth2.GenerateAuthorization(authProvider.GetConfig())
+
+	sess.SetOAuth2State(auth.State())
+	sess.SetOAuth2CodeVerifier(auth.CodeVerifier())
+
+	html.Redirect(w, r, auth.RedirectURL())
 }

+ 5 - 5
internal/ui/session/session.go

@@ -4,7 +4,6 @@
 package session // import "miniflux.app/v2/internal/ui/session"
 
 import (
-	"miniflux.app/v2/internal/crypto"
 	"miniflux.app/v2/internal/storage"
 )
 
@@ -14,11 +13,12 @@ type Session struct {
 	sessionID string
 }
 
-// NewOAuth2State generates a new OAuth2 state and stores the value into the database.
-func (s *Session) NewOAuth2State() string {
-	state := crypto.GenerateRandomString(32)
+func (s *Session) SetOAuth2State(state string) {
 	s.store.UpdateAppSessionField(s.sessionID, "oauth2_state", state)
-	return state
+}
+
+func (s *Session) SetOAuth2CodeVerifier(codeVerfier string) {
+	s.store.UpdateAppSessionField(s.sessionID, "oauth2_code_verifier", codeVerfier)
 }
 
 // NewFlashMessage creates a new flash message.