| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415 |
- package otoauth2
- import (
- "context"
- "crypto/rand"
- "crypto/tls"
- "crypto/x509"
- "encoding/base64"
- "encoding/json"
- "fmt"
- "io"
- "net/http"
- "os"
- "sync"
- "time"
- authTypes "github.com/OliveTin/OliveTin/internal/auth/authpublic"
- config "github.com/OliveTin/OliveTin/internal/config"
- log "github.com/sirupsen/logrus"
- "golang.org/x/oauth2"
- )
- type OAuth2Handler struct {
- cfg *config.Config
- mu sync.RWMutex
- registeredStates map[string]*oauth2State
- registeredProviders map[string]*oauth2.Config
- }
- func NewOAuth2Handler(cfg *config.Config) *OAuth2Handler {
- h := &OAuth2Handler{
- cfg: cfg,
- }
- h.registeredStates = make(map[string]*oauth2State)
- h.registeredProviders = make(map[string]*oauth2.Config)
- for providerName, providerConfig := range cfg.AuthOAuth2Providers {
- completeProviderConfig(providerName, providerConfig)
- newConfig := &oauth2.Config{
- ClientID: providerConfig.ClientID,
- ClientSecret: providerConfig.ClientSecret,
- Scopes: providerConfig.Scopes,
- Endpoint: oauth2.Endpoint{
- AuthURL: providerConfig.AuthUrl,
- TokenURL: providerConfig.TokenUrl,
- },
- RedirectURL: cfg.AuthOAuth2RedirectURL,
- }
- h.registeredProviders[providerName] = newConfig
- log.Debugf("Dumping newly registered provider: %v = %+v", providerName, providerConfig)
- }
- return h
- }
- type oauth2State struct {
- providerConfig *oauth2.Config
- providerName string
- Username string
- Usergroup string
- }
// assignIfEmpty stores value in *target, but only when *target currently holds
// the empty string; a non-empty target is left untouched.
func assignIfEmpty(target *string, value string) {
	if *target != "" {
		return
	}
	*target = value
}
- func completeProviderConfig(providerName string, providerConfig *config.OAuth2Provider) {
- dbConfig, ok := oauth2ProviderDatabase[providerName]
- if ok {
- assignIfEmpty(&providerConfig.Name, dbConfig.Name)
- assignIfEmpty(&providerConfig.Title, dbConfig.Title)
- assignIfEmpty(&providerConfig.WhoamiUrl, dbConfig.WhoamiUrl)
- assignIfEmpty(&providerConfig.TokenUrl, dbConfig.TokenUrl)
- assignIfEmpty(&providerConfig.AuthUrl, dbConfig.AuthUrl)
- assignIfEmpty(&providerConfig.Icon, dbConfig.Icon)
- assignIfEmpty(&providerConfig.UsernameField, dbConfig.UsernameField)
- if providerConfig.Scopes == nil {
- providerConfig.Scopes = dbConfig.Scopes
- }
- } else {
- log.Warnf("Provider not found in database: %v", providerName)
- }
- }
- func (h *OAuth2Handler) getOAuth2Config(providerName string) (*oauth2.Config, error) {
- config, ok := h.registeredProviders[providerName]
- if !ok {
- return nil, fmt.Errorf("provider not found in config: %v", providerName)
- }
- return config, nil
- }
// randString reads nByte bytes from crypto/rand and returns them encoded with
// the padded URL-safe base64 alphabet. The read error, if any, is returned
// together with an empty string.
func randString(nByte int) (string, error) {
	raw := make([]byte, nByte)

	_, err := io.ReadFull(rand.Reader, raw)
	if err != nil {
		return "", err
	}

	encoded := base64.URLEncoding.EncodeToString(raw)
	return encoded, nil
}
- func (h *OAuth2Handler) cookieSecure(r *http.Request) bool {
- useTLS := r.TLS != nil || r.Header.Get("X-Forwarded-Proto") == "https"
- return useTLS || h.cfg.Security.ForceSecureCookies
- }
- func (h *OAuth2Handler) setOAuthCallbackCookie(w http.ResponseWriter, r *http.Request, name, value string) {
- cookie := &http.Cookie{
- Name: name,
- Value: value,
- MaxAge: 900, // 15 minutes
- Secure: h.cookieSecure(r),
- HttpOnly: true,
- Path: "/",
- SameSite: http.SameSiteLaxMode,
- }
- http.SetCookie(w, cookie)
- }
- func (h *OAuth2Handler) HandleOAuthLogin(w http.ResponseWriter, r *http.Request) {
- state, err := randString(16)
- if err != nil {
- http.Error(w, err.Error(), http.StatusInternalServerError)
- return
- }
- providerName := r.URL.Query().Get("provider")
- provider, err := h.getOAuth2Config(providerName)
- if err != nil {
- log.Errorf("Failed to get provider config: %v %v", providerName, err)
- http.Error(w, err.Error(), http.StatusBadRequest)
- return
- }
- h.mu.Lock()
- h.registeredStates[state] = &oauth2State{
- providerConfig: provider,
- providerName: providerName,
- Username: "",
- }
- h.mu.Unlock()
- h.setOAuthCallbackCookie(w, r, "olivetin-sid-oauth", state)
- log.Infof("OAuth2 state: %v mapped to provider %v (found: %v), now redirecting", state, providerName, provider != nil)
- http.Redirect(w, r, provider.AuthCodeURL(state), http.StatusFound)
- }
- func (h *OAuth2Handler) validateStateMatch(queryState, cookieState string) bool {
- return queryState == cookieState
- }
- func (h *OAuth2Handler) checkOAuthCallbackCookie(w http.ResponseWriter, r *http.Request) (*oauth2State, string, bool) {
- cookie, err := r.Cookie("olivetin-sid-oauth")
- if err != nil {
- log.Errorf("Failed to get state cookie: %v", err)
- http.Error(w, "State not found", http.StatusBadRequest)
- return nil, "", false
- }
- state := cookie.Value
- if !h.validateStateMatch(r.URL.Query().Get("state"), state) {
- log.Errorf("State mismatch: %v != %v", r.URL.Query().Get("state"), state)
- http.Error(w, "State mismatch", http.StatusBadRequest)
- return nil, state, false
- }
- h.mu.RLock()
- registeredState, ok := h.registeredStates[state]
- h.mu.RUnlock()
- if !ok {
- log.Errorf("State not found in server: %v", state)
- http.Error(w, "State not found in server", http.StatusBadRequest)
- return nil, state, false
- }
- return registeredState, state, true
- }
// HttpClientSettings bundles the transport and timeout used to build the HTTP
// clients that talk to an OAuth2 provider.
type HttpClientSettings struct {
	// Transport is the base round tripper, including its TLS configuration.
	Transport *http.Transport

	// Timeout is applied as the http.Client timeout of the clients built
	// from these settings.
	Timeout time.Duration
}
- func getOAuth2HttpClient(providerConfig *config.OAuth2Provider) *HttpClientSettings {
- config := &HttpClientSettings{
- Transport: &http.Transport{
- TLSClientConfig: &tls.Config{InsecureSkipVerify: providerConfig.InsecureSkipVerify},
- },
- Timeout: time.Duration(min(3, providerConfig.CallbackTimeout)) * time.Second,
- }
- if providerConfig.CertBundlePath != "" {
- config.Transport.TLSClientConfig.RootCAs = getOAuthCertBundle(providerConfig)
- }
- return config
- }
- func getOAuthCertBundle(providerConfig *config.OAuth2Provider) *x509.CertPool {
- caCert, err := os.ReadFile(providerConfig.CertBundlePath)
- if err != nil {
- log.Errorf("OAuth2 Cert Bundle - failed to read file: %v", err)
- return nil
- }
- caCertPool := x509.NewCertPool()
- if ok := caCertPool.AppendCertsFromPEM(caCert); !ok {
- log.Errorf("OAuth2 Cert Bundle - failed to append certificates from PEM")
- }
- return caCertPool
- }
- func (h *OAuth2Handler) exchangeOAuthCode(ctx context.Context, providerConfig *oauth2.Config, code string, clientSettings *HttpClientSettings) (*oauth2.Token, error) {
- exchangeClient := &http.Client{
- Transport: clientSettings.Transport,
- Timeout: clientSettings.Timeout,
- }
- ctx = context.WithValue(ctx, oauth2.HTTPClient, exchangeClient)
- return providerConfig.Exchange(ctx, code)
- }
- func (h *OAuth2Handler) createUserInfoClient(ctx context.Context, providerConfig *oauth2.Config, tok *oauth2.Token, clientSettings *HttpClientSettings) *http.Client {
- return &http.Client{
- Transport: &oauth2.Transport{
- Source: providerConfig.TokenSource(ctx, tok),
- Base: clientSettings.Transport,
- },
- Timeout: clientSettings.Timeout,
- }
- }
- func (h *OAuth2Handler) computeUsergroup(userinfo *UserInfo, providerConfig *config.OAuth2Provider) string {
- usergroup := userinfo.Usergroup
- if providerConfig != nil && providerConfig.AddToUsergroup != "" {
- if usergroup != "" {
- usergroup = usergroup + " " + providerConfig.AddToUsergroup
- } else {
- usergroup = providerConfig.AddToUsergroup
- }
- }
- return usergroup
- }
- func (h *OAuth2Handler) HandleOAuthCallback(w http.ResponseWriter, r *http.Request) {
- log.Infof("OAuth2 Callback received")
- registeredState, state, ok := h.checkOAuthCallbackCookie(w, r)
- if !ok {
- return
- }
- code := r.FormValue("code")
- log.WithFields(log.Fields{
- "state": state,
- "token-code": code,
- }).Debug("OAuth2 Token Code")
- providerConfig := h.cfg.AuthOAuth2Providers[registeredState.providerName]
- clientSettings := getOAuth2HttpClient(providerConfig)
- ctx := context.Background()
- tok, err := h.exchangeOAuthCode(ctx, registeredState.providerConfig, code, clientSettings)
- if err != nil {
- log.Errorf("Failed to exchange code: %v", err)
- http.Error(w, "Failed to exchange code", http.StatusBadRequest)
- return
- }
- userInfoClient := h.createUserInfoClient(ctx, registeredState.providerConfig, tok, clientSettings)
- userinfo := getUserInfo(h.cfg, userInfoClient, providerConfig)
- h.mu.Lock()
- h.registeredStates[state].Username = userinfo.Username
- h.registeredStates[state].Usergroup = h.computeUsergroup(userinfo, providerConfig)
- h.mu.Unlock()
- http.Redirect(w, r, "/", http.StatusFound)
- }
// UserInfo is the identity extracted from a provider's user info response.
type UserInfo struct {
	// Username is read from the field named by the provider's UsernameField.
	Username string

	// Usergroup is read from the field named by the provider's UserGroupField.
	Usergroup string
}
- //gocyclo:ignore
- func getUserInfo(cfg *config.Config, client *http.Client, provider *config.OAuth2Provider) *UserInfo {
- ret := &UserInfo{}
- res, err := client.Get(provider.WhoamiUrl)
- if err != nil {
- log.Errorf("Failed to get user data: %v", err)
- return ret
- }
- if res.StatusCode != http.StatusOK {
- log.Errorf("Failed to get user data: %v", res.StatusCode)
- return ret
- }
- defer res.Body.Close()
- contents, err := io.ReadAll(res.Body)
- if err != nil {
- log.Errorf("Failed to read user data: %v", err)
- return ret
- }
- var userData map[string]any
- if cfg.InsecureAllowDumpOAuth2UserData {
- log.Debugf("OAuth2 User Data: %v+", string(contents))
- }
- err = json.Unmarshal([]byte(contents), &userData)
- if err != nil {
- log.Errorf("Failed to unmarshal user data: %v", err)
- return ret
- }
- ret.Username = getDataField(userData, provider.UsernameField)
- ret.Usergroup = getDataField(userData, provider.UserGroupField)
- return ret
- }
- func getDataField(data map[string]any, field string) string {
- if field == "" {
- return ""
- }
- val, ok := data[field]
- if !ok {
- log.Errorf("Failed to get field from user data: %v / %v", data, field)
- return ""
- }
- stringVal, ok := val.(string)
- if !ok {
- log.Errorf("Field %v is not a string: %v", field, val)
- return ""
- }
- return stringVal
- }
- func (h *OAuth2Handler) lookupOAuth2UserByState(state string) (*authTypes.AuthenticatedUser, bool) {
- h.mu.RLock()
- serverState, found := h.registeredStates[state]
- if !found {
- h.mu.RUnlock()
- return nil, false
- }
- user := &authTypes.AuthenticatedUser{
- Username: serverState.Username,
- UsergroupLine: serverState.Usergroup,
- Provider: "oauth2",
- SID: state,
- }
- h.mu.RUnlock()
- return user, true
- }
- func (h *OAuth2Handler) RevokeSession(sid string) {
- h.mu.Lock()
- defer h.mu.Unlock()
- delete(h.registeredStates, sid)
- }
- func (h *OAuth2Handler) CheckUserFromOAuth2Cookie(context *authTypes.AuthCheckingContext) *authTypes.AuthenticatedUser {
- cookie, err := context.Request.Cookie("olivetin-sid-oauth")
- if err != nil || cookie.Value == "" {
- return nil
- }
- user, found := h.lookupOAuth2UserByState(cookie.Value)
- if !found {
- log.WithFields(log.Fields{
- "sid": cookie.Value,
- "provider": "oauth2",
- }).Warnf("Stale session")
- return nil
- }
- return user
- }
|