restapi_auth_oauth2.go 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415
  1. package otoauth2
  2. import (
  3. "context"
  4. "crypto/rand"
  5. "crypto/tls"
  6. "crypto/x509"
  7. "encoding/base64"
  8. "encoding/json"
  9. "fmt"
  10. "io"
  11. "net/http"
  12. "os"
  13. "sync"
  14. "time"
  15. authTypes "github.com/OliveTin/OliveTin/internal/auth/authpublic"
  16. config "github.com/OliveTin/OliveTin/internal/config"
  17. log "github.com/sirupsen/logrus"
  18. "golang.org/x/oauth2"
  19. )
  20. type OAuth2Handler struct {
  21. cfg *config.Config
  22. mu sync.RWMutex
  23. registeredStates map[string]*oauth2State
  24. registeredProviders map[string]*oauth2.Config
  25. }
  26. func NewOAuth2Handler(cfg *config.Config) *OAuth2Handler {
  27. h := &OAuth2Handler{
  28. cfg: cfg,
  29. }
  30. h.registeredStates = make(map[string]*oauth2State)
  31. h.registeredProviders = make(map[string]*oauth2.Config)
  32. for providerName, providerConfig := range cfg.AuthOAuth2Providers {
  33. completeProviderConfig(providerName, providerConfig)
  34. newConfig := &oauth2.Config{
  35. ClientID: providerConfig.ClientID,
  36. ClientSecret: providerConfig.ClientSecret,
  37. Scopes: providerConfig.Scopes,
  38. Endpoint: oauth2.Endpoint{
  39. AuthURL: providerConfig.AuthUrl,
  40. TokenURL: providerConfig.TokenUrl,
  41. },
  42. RedirectURL: cfg.AuthOAuth2RedirectURL,
  43. }
  44. h.registeredProviders[providerName] = newConfig
  45. log.Debugf("Dumping newly registered provider: %v = %+v", providerName, providerConfig)
  46. }
  47. return h
  48. }
  49. type oauth2State struct {
  50. providerConfig *oauth2.Config
  51. providerName string
  52. Username string
  53. Usergroup string
  54. }
  55. func assignIfEmpty(target *string, value string) {
  56. if *target == "" {
  57. *target = value
  58. }
  59. }
  60. func completeProviderConfig(providerName string, providerConfig *config.OAuth2Provider) {
  61. dbConfig, ok := oauth2ProviderDatabase[providerName]
  62. if ok {
  63. assignIfEmpty(&providerConfig.Name, dbConfig.Name)
  64. assignIfEmpty(&providerConfig.Title, dbConfig.Title)
  65. assignIfEmpty(&providerConfig.WhoamiUrl, dbConfig.WhoamiUrl)
  66. assignIfEmpty(&providerConfig.TokenUrl, dbConfig.TokenUrl)
  67. assignIfEmpty(&providerConfig.AuthUrl, dbConfig.AuthUrl)
  68. assignIfEmpty(&providerConfig.Icon, dbConfig.Icon)
  69. assignIfEmpty(&providerConfig.UsernameField, dbConfig.UsernameField)
  70. if providerConfig.Scopes == nil {
  71. providerConfig.Scopes = dbConfig.Scopes
  72. }
  73. } else {
  74. log.Warnf("Provider not found in database: %v", providerName)
  75. }
  76. }
  77. func (h *OAuth2Handler) getOAuth2Config(providerName string) (*oauth2.Config, error) {
  78. config, ok := h.registeredProviders[providerName]
  79. if !ok {
  80. return nil, fmt.Errorf("provider not found in config: %v", providerName)
  81. }
  82. return config, nil
  83. }
  84. func randString(nByte int) (string, error) {
  85. b := make([]byte, nByte)
  86. if _, err := io.ReadFull(rand.Reader, b); err != nil {
  87. return "", err
  88. }
  89. return base64.URLEncoding.EncodeToString(b), nil
  90. }
  91. func (h *OAuth2Handler) cookieSecure(r *http.Request) bool {
  92. useTLS := r.TLS != nil || r.Header.Get("X-Forwarded-Proto") == "https"
  93. return useTLS || h.cfg.Security.ForceSecureCookies
  94. }
  95. func (h *OAuth2Handler) setOAuthCallbackCookie(w http.ResponseWriter, r *http.Request, name, value string) {
  96. cookie := &http.Cookie{
  97. Name: name,
  98. Value: value,
  99. MaxAge: 900, // 15 minutes
  100. Secure: h.cookieSecure(r),
  101. HttpOnly: true,
  102. Path: "/",
  103. SameSite: http.SameSiteLaxMode,
  104. }
  105. http.SetCookie(w, cookie)
  106. }
  107. func (h *OAuth2Handler) HandleOAuthLogin(w http.ResponseWriter, r *http.Request) {
  108. state, err := randString(16)
  109. if err != nil {
  110. http.Error(w, err.Error(), http.StatusInternalServerError)
  111. return
  112. }
  113. providerName := r.URL.Query().Get("provider")
  114. provider, err := h.getOAuth2Config(providerName)
  115. if err != nil {
  116. log.Errorf("Failed to get provider config: %v %v", providerName, err)
  117. http.Error(w, err.Error(), http.StatusBadRequest)
  118. return
  119. }
  120. h.mu.Lock()
  121. h.registeredStates[state] = &oauth2State{
  122. providerConfig: provider,
  123. providerName: providerName,
  124. Username: "",
  125. }
  126. h.mu.Unlock()
  127. h.setOAuthCallbackCookie(w, r, "olivetin-sid-oauth", state)
  128. log.Infof("OAuth2 state: %v mapped to provider %v (found: %v), now redirecting", state, providerName, provider != nil)
  129. http.Redirect(w, r, provider.AuthCodeURL(state), http.StatusFound)
  130. }
  131. func (h *OAuth2Handler) validateStateMatch(queryState, cookieState string) bool {
  132. return queryState == cookieState
  133. }
  134. func (h *OAuth2Handler) checkOAuthCallbackCookie(w http.ResponseWriter, r *http.Request) (*oauth2State, string, bool) {
  135. cookie, err := r.Cookie("olivetin-sid-oauth")
  136. if err != nil {
  137. log.Errorf("Failed to get state cookie: %v", err)
  138. http.Error(w, "State not found", http.StatusBadRequest)
  139. return nil, "", false
  140. }
  141. state := cookie.Value
  142. if !h.validateStateMatch(r.URL.Query().Get("state"), state) {
  143. log.Errorf("State mismatch: %v != %v", r.URL.Query().Get("state"), state)
  144. http.Error(w, "State mismatch", http.StatusBadRequest)
  145. return nil, state, false
  146. }
  147. h.mu.RLock()
  148. registeredState, ok := h.registeredStates[state]
  149. h.mu.RUnlock()
  150. if !ok {
  151. log.Errorf("State not found in server: %v", state)
  152. http.Error(w, "State not found in server", http.StatusBadRequest)
  153. return nil, state, false
  154. }
  155. return registeredState, state, true
  156. }
  157. type HttpClientSettings struct {
  158. Transport *http.Transport
  159. Timeout time.Duration
  160. }
  161. func getOAuth2HttpClient(providerConfig *config.OAuth2Provider) *HttpClientSettings {
  162. config := &HttpClientSettings{
  163. Transport: &http.Transport{
  164. TLSClientConfig: &tls.Config{InsecureSkipVerify: providerConfig.InsecureSkipVerify},
  165. },
  166. Timeout: time.Duration(min(3, providerConfig.CallbackTimeout)) * time.Second,
  167. }
  168. if providerConfig.CertBundlePath != "" {
  169. config.Transport.TLSClientConfig.RootCAs = getOAuthCertBundle(providerConfig)
  170. }
  171. return config
  172. }
  173. func getOAuthCertBundle(providerConfig *config.OAuth2Provider) *x509.CertPool {
  174. caCert, err := os.ReadFile(providerConfig.CertBundlePath)
  175. if err != nil {
  176. log.Errorf("OAuth2 Cert Bundle - failed to read file: %v", err)
  177. return nil
  178. }
  179. caCertPool := x509.NewCertPool()
  180. if ok := caCertPool.AppendCertsFromPEM(caCert); !ok {
  181. log.Errorf("OAuth2 Cert Bundle - failed to append certificates from PEM")
  182. }
  183. return caCertPool
  184. }
  185. func (h *OAuth2Handler) exchangeOAuthCode(ctx context.Context, providerConfig *oauth2.Config, code string, clientSettings *HttpClientSettings) (*oauth2.Token, error) {
  186. exchangeClient := &http.Client{
  187. Transport: clientSettings.Transport,
  188. Timeout: clientSettings.Timeout,
  189. }
  190. ctx = context.WithValue(ctx, oauth2.HTTPClient, exchangeClient)
  191. return providerConfig.Exchange(ctx, code)
  192. }
  193. func (h *OAuth2Handler) createUserInfoClient(ctx context.Context, providerConfig *oauth2.Config, tok *oauth2.Token, clientSettings *HttpClientSettings) *http.Client {
  194. return &http.Client{
  195. Transport: &oauth2.Transport{
  196. Source: providerConfig.TokenSource(ctx, tok),
  197. Base: clientSettings.Transport,
  198. },
  199. Timeout: clientSettings.Timeout,
  200. }
  201. }
  202. func (h *OAuth2Handler) computeUsergroup(userinfo *UserInfo, providerConfig *config.OAuth2Provider) string {
  203. usergroup := userinfo.Usergroup
  204. if providerConfig != nil && providerConfig.AddToUsergroup != "" {
  205. if usergroup != "" {
  206. usergroup = usergroup + " " + providerConfig.AddToUsergroup
  207. } else {
  208. usergroup = providerConfig.AddToUsergroup
  209. }
  210. }
  211. return usergroup
  212. }
  213. func (h *OAuth2Handler) HandleOAuthCallback(w http.ResponseWriter, r *http.Request) {
  214. log.Infof("OAuth2 Callback received")
  215. registeredState, state, ok := h.checkOAuthCallbackCookie(w, r)
  216. if !ok {
  217. return
  218. }
  219. code := r.FormValue("code")
  220. log.WithFields(log.Fields{
  221. "state": state,
  222. "token-code": code,
  223. }).Debug("OAuth2 Token Code")
  224. providerConfig := h.cfg.AuthOAuth2Providers[registeredState.providerName]
  225. clientSettings := getOAuth2HttpClient(providerConfig)
  226. ctx := context.Background()
  227. tok, err := h.exchangeOAuthCode(ctx, registeredState.providerConfig, code, clientSettings)
  228. if err != nil {
  229. log.Errorf("Failed to exchange code: %v", err)
  230. http.Error(w, "Failed to exchange code", http.StatusBadRequest)
  231. return
  232. }
  233. userInfoClient := h.createUserInfoClient(ctx, registeredState.providerConfig, tok, clientSettings)
  234. userinfo := getUserInfo(h.cfg, userInfoClient, providerConfig)
  235. h.mu.Lock()
  236. h.registeredStates[state].Username = userinfo.Username
  237. h.registeredStates[state].Usergroup = h.computeUsergroup(userinfo, providerConfig)
  238. h.mu.Unlock()
  239. http.Redirect(w, r, "/", http.StatusFound)
  240. }
  241. type UserInfo struct {
  242. Username string
  243. Usergroup string
  244. }
  245. //gocyclo:ignore
  246. func getUserInfo(cfg *config.Config, client *http.Client, provider *config.OAuth2Provider) *UserInfo {
  247. ret := &UserInfo{}
  248. res, err := client.Get(provider.WhoamiUrl)
  249. if err != nil {
  250. log.Errorf("Failed to get user data: %v", err)
  251. return ret
  252. }
  253. if res.StatusCode != http.StatusOK {
  254. log.Errorf("Failed to get user data: %v", res.StatusCode)
  255. return ret
  256. }
  257. defer res.Body.Close()
  258. contents, err := io.ReadAll(res.Body)
  259. if err != nil {
  260. log.Errorf("Failed to read user data: %v", err)
  261. return ret
  262. }
  263. var userData map[string]any
  264. if cfg.InsecureAllowDumpOAuth2UserData {
  265. log.Debugf("OAuth2 User Data: %v+", string(contents))
  266. }
  267. err = json.Unmarshal([]byte(contents), &userData)
  268. if err != nil {
  269. log.Errorf("Failed to unmarshal user data: %v", err)
  270. return ret
  271. }
  272. ret.Username = getDataField(userData, provider.UsernameField)
  273. ret.Usergroup = getDataField(userData, provider.UserGroupField)
  274. return ret
  275. }
  276. func getDataField(data map[string]any, field string) string {
  277. if field == "" {
  278. return ""
  279. }
  280. val, ok := data[field]
  281. if !ok {
  282. log.Errorf("Failed to get field from user data: %v / %v", data, field)
  283. return ""
  284. }
  285. stringVal, ok := val.(string)
  286. if !ok {
  287. log.Errorf("Field %v is not a string: %v", field, val)
  288. return ""
  289. }
  290. return stringVal
  291. }
  292. func (h *OAuth2Handler) lookupOAuth2UserByState(state string) (*authTypes.AuthenticatedUser, bool) {
  293. h.mu.RLock()
  294. serverState, found := h.registeredStates[state]
  295. if !found {
  296. h.mu.RUnlock()
  297. return nil, false
  298. }
  299. user := &authTypes.AuthenticatedUser{
  300. Username: serverState.Username,
  301. UsergroupLine: serverState.Usergroup,
  302. Provider: "oauth2",
  303. SID: state,
  304. }
  305. h.mu.RUnlock()
  306. return user, true
  307. }
  308. func (h *OAuth2Handler) RevokeSession(sid string) {
  309. h.mu.Lock()
  310. defer h.mu.Unlock()
  311. delete(h.registeredStates, sid)
  312. }
  313. func (h *OAuth2Handler) CheckUserFromOAuth2Cookie(context *authTypes.AuthCheckingContext) *authTypes.AuthenticatedUser {
  314. cookie, err := context.Request.Cookie("olivetin-sid-oauth")
  315. if err != nil || cookie.Value == "" {
  316. return nil
  317. }
  318. user, found := h.lookupOAuth2UserByState(cookie.Value)
  319. if !found {
  320. log.WithFields(log.Fields{
  321. "sid": cookie.Value,
  322. "provider": "oauth2",
  323. }).Warnf("Stale session")
  324. return nil
  325. }
  326. return user
  327. }