|
|
@@ -11,6 +11,7 @@ import (
|
|
|
"io"
|
|
|
"net/http"
|
|
|
"os"
|
|
|
+ "sync"
|
|
|
"time"
|
|
|
|
|
|
authTypes "github.com/OliveTin/OliveTin/internal/auth/authpublic"
|
|
|
@@ -21,6 +22,7 @@ import (
|
|
|
|
|
|
type OAuth2Handler struct {
|
|
|
cfg *config.Config
|
|
|
+ mu sync.RWMutex
|
|
|
registeredStates map[string]*oauth2State
|
|
|
registeredProviders map[string]*oauth2.Config
|
|
|
}
|
|
|
@@ -144,11 +146,13 @@ func (h *OAuth2Handler) HandleOAuthLogin(w http.ResponseWriter, r *http.Request)
|
|
|
return
|
|
|
}
|
|
|
|
|
|
+ h.mu.Lock()
|
|
|
h.registeredStates[state] = &oauth2State{
|
|
|
providerConfig: provider,
|
|
|
providerName: providerName,
|
|
|
Username: "",
|
|
|
}
|
|
|
+ h.mu.Unlock()
|
|
|
|
|
|
h.setOAuthCallbackCookie(w, r, "olivetin-sid-oauth", state)
|
|
|
|
|
|
@@ -177,7 +181,9 @@ func (h *OAuth2Handler) checkOAuthCallbackCookie(w http.ResponseWriter, r *http.
|
|
|
return nil, state, false
|
|
|
}
|
|
|
|
|
|
+ h.mu.RLock()
|
|
|
registeredState, ok := h.registeredStates[state]
|
|
|
+ h.mu.RUnlock()
|
|
|
if !ok {
|
|
|
log.Errorf("State not found in server: %v", state)
|
|
|
http.Error(w, "State not found in server", http.StatusBadRequest)
|
|
|
@@ -287,8 +293,10 @@ func (h *OAuth2Handler) HandleOAuthCallback(w http.ResponseWriter, r *http.Reque
|
|
|
userInfoClient := h.createUserInfoClient(ctx, registeredState.providerConfig, tok, clientSettings)
|
|
|
userinfo := getUserInfo(h.cfg, userInfoClient, providerConfig)
|
|
|
|
|
|
+ h.mu.Lock()
|
|
|
h.registeredStates[state].Username = userinfo.Username
|
|
|
h.registeredStates[state].Usergroup = h.computeUsergroup(userinfo, providerConfig)
|
|
|
+ h.mu.Unlock()
|
|
|
|
|
|
http.Redirect(w, r, "/", http.StatusFound)
|
|
|
}
|
|
|
@@ -366,34 +374,36 @@ func getDataField(data map[string]any, field string) string {
|
|
|
return stringVal
|
|
|
}
|
|
|
|
|
|
-func (h *OAuth2Handler) CheckUserFromOAuth2Cookie(context *authTypes.AuthCheckingContext) *authTypes.AuthenticatedUser {
|
|
|
- cookie, err := context.Request.Cookie("olivetin-sid-oauth")
|
|
|
-
|
|
|
- user := &authTypes.AuthenticatedUser{}
|
|
|
-
|
|
|
- if err != nil {
|
|
|
- return nil
|
|
|
+func (h *OAuth2Handler) lookupOAuth2UserByState(state string) (*authTypes.AuthenticatedUser, bool) {
|
|
|
+ h.mu.RLock()
|
|
|
+ serverState, found := h.registeredStates[state]
|
|
|
+ if !found {
|
|
|
+ h.mu.RUnlock()
|
|
|
+ return nil, false
|
|
|
+ }
|
|
|
+ user := &authTypes.AuthenticatedUser{
|
|
|
+ Username: serverState.Username,
|
|
|
+ UsergroupLine: serverState.Usergroup,
|
|
|
+ Provider: "oauth2",
|
|
|
+ SID: state,
|
|
|
}
|
|
|
+ h.mu.RUnlock()
|
|
|
+ return user, true
|
|
|
+}
|
|
|
|
|
|
- if cookie.Value == "" {
|
|
|
+func (h *OAuth2Handler) CheckUserFromOAuth2Cookie(context *authTypes.AuthCheckingContext) *authTypes.AuthenticatedUser {
|
|
|
+ cookie, err := context.Request.Cookie("olivetin-sid-oauth")
|
|
|
+ if err != nil || cookie.Value == "" {
|
|
|
return nil
|
|
|
}
|
|
|
|
|
|
- serverState, found := h.registeredStates[cookie.Value]
|
|
|
-
|
|
|
+ user, found := h.lookupOAuth2UserByState(cookie.Value)
|
|
|
if !found {
|
|
|
log.WithFields(log.Fields{
|
|
|
"sid": cookie.Value,
|
|
|
"provider": "oauth2",
|
|
|
}).Warnf("Stale session")
|
|
|
-
|
|
|
return nil
|
|
|
}
|
|
|
-
|
|
|
- user.Username = serverState.Username
|
|
|
- user.UsergroupLine = serverState.Usergroup
|
|
|
- user.Provider = "oauth2"
|
|
|
- user.SID = cookie.Value
|
|
|
-
|
|
|
return user
|
|
|
}
|