jwt.go 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251
  1. package otjwt
  import (
  	"context"
  	"crypto/rsa"
  	"errors"
  	"fmt"
  	"os"
  	"strconv"
  	"strings"
  	"sync"
  	"time"

  	"github.com/MicahParks/keyfunc/v3"
  	authTypes "github.com/OliveTin/OliveTin/internal/auth/authpublic"
  	"github.com/OliveTin/OliveTin/internal/config"
  	"github.com/golang-jwt/jwt/v5"
  	log "github.com/sirupsen/logrus"
  )
  17. func parseJwtToken(cfg *config.Config, jwtString string) (*jwt.Token, error) {
  18. if cfg.AuthJwtCertsURL != "" {
  19. return parseJwtTokenWithRemoteKey(cfg, jwtString)
  20. }
  21. if cfg.AuthJwtPubKeyPath != "" {
  22. return parseJwtTokenWithLocalKey(cfg, jwtString)
  23. }
  24. if cfg.AuthJwtHmacSecret == "" {
  25. return nil, errors.New("no JWT authentication method configured")
  26. }
  27. return parseJwtTokenWithHMAC(cfg, jwtString)
  28. }
  29. func parserOptionsWithAudience(cfg *config.Config) []jwt.ParserOption {
  30. if cfg.AuthJwtAud == "" {
  31. return nil
  32. }
  33. return []jwt.ParserOption{jwt.WithAudience(cfg.AuthJwtAud)}
  34. }
  35. func getClaimsFromJwtToken(cfg *config.Config, jwtString string) (jwt.MapClaims, error) {
  36. token, err := parseJwtToken(cfg, jwtString)
  37. if err != nil {
  38. log.Errorf("jwt parse failure: %v", err)
  39. return nil, errors.New("jwt parse failure")
  40. }
  41. if claims, ok := token.Claims.(jwt.MapClaims); ok && token.Valid {
  42. return claims, nil
  43. } else {
  44. return nil, errors.New("jwt token isn't valid")
  45. }
  46. }
  47. func parseJwtTokenWithRemoteKey(cfg *config.Config, jwtToken string) (*jwt.Token, error) {
  48. err := initJwks(cfg)
  49. if err != nil {
  50. log.Errorf("jwt init JWKS failure: %v", err)
  51. return nil, err
  52. }
  53. opts := parserOptionsWithAudience(cfg)
  54. return jwt.Parse(jwtToken, jwksVerifier.Keyfunc, opts...)
  55. }
  56. var (
  57. pubKeyBytes []byte = nil
  58. pubKey *rsa.PublicKey
  59. loadedKeyPath string
  60. jwksVerifier keyfunc.Keyfunc
  61. jwksOnce sync.Once
  62. jwksInitErr error
  63. localKeyMutex sync.RWMutex
  64. localKeyInitErr error
  65. )
  66. func initJwks(cfg *config.Config) error {
  67. jwksOnce.Do(func() {
  68. if cfg.AuthJwtCertsURL != "" {
  69. ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
  70. defer cancel()
  71. var err error
  72. jwksVerifier, err = keyfunc.NewDefaultCtx(ctx, []string{
  73. cfg.AuthJwtCertsURL,
  74. })
  75. if err != nil {
  76. log.Errorf("Init JWKS Failure: %v", err)
  77. jwksInitErr = err
  78. }
  79. }
  80. })
  81. return jwksInitErr
  82. }
  83. func loadPublicKeyFromFile(keyPath string) error {
  84. keyBytes, err := os.ReadFile(keyPath)
  85. if err != nil {
  86. return fmt.Errorf("couldn't read public key from file %s", keyPath)
  87. }
  88. parsedKey, err := jwt.ParseRSAPublicKeyFromPEM(keyBytes)
  89. if err != nil {
  90. return fmt.Errorf("error parsing public key object (from %s)", keyPath)
  91. }
  92. pubKeyBytes = keyBytes
  93. pubKey = parsedKey
  94. loadedKeyPath = keyPath
  95. localKeyInitErr = nil
  96. return nil
  97. }
  98. func isKeyLoadedForPath(keyPath string) bool {
  99. return pubKeyBytes != nil && loadedKeyPath == keyPath
  100. }
  101. func readLocalPublicKeyWithLock(keyPath string) error {
  102. localKeyMutex.RLock()
  103. alreadyLoaded := isKeyLoadedForPath(keyPath)
  104. localKeyMutex.RUnlock()
  105. if alreadyLoaded {
  106. return nil
  107. }
  108. localKeyMutex.Lock()
  109. defer localKeyMutex.Unlock()
  110. if isKeyLoadedForPath(keyPath) {
  111. return nil
  112. }
  113. localKeyInitErr = loadPublicKeyFromFile(keyPath)
  114. return localKeyInitErr
  115. }
  116. func readLocalPublicKey(cfg *config.Config) error {
  117. if cfg.AuthJwtPubKeyPath == "" {
  118. return errors.New("no JWT public key path configured")
  119. }
  120. return readLocalPublicKeyWithLock(cfg.AuthJwtPubKeyPath)
  121. }
  122. func parseJwtTokenWithLocalKey(cfg *config.Config, jwtString string) (*jwt.Token, error) {
  123. err := readLocalPublicKey(cfg)
  124. if err != nil {
  125. return nil, err
  126. }
  127. keyFunc := func(token *jwt.Token) (interface{}, error) {
  128. if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok {
  129. return nil, fmt.Errorf("parseJwt expected token algorithm RSA but got: %v", token.Header["alg"])
  130. }
  131. return pubKey, nil
  132. }
  133. opts := parserOptionsWithAudience(cfg)
  134. return jwt.Parse(jwtString, keyFunc, opts...)
  135. }
  136. // Hash-based Message Authentication Code
  137. func parseJwtTokenWithHMAC(cfg *config.Config, jwtString string) (*jwt.Token, error) {
  138. keyFunc := func(token *jwt.Token) (interface{}, error) {
  139. if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
  140. return nil, fmt.Errorf("parseJwt expected token algorithm HMAC but got: %v", token.Header["alg"])
  141. }
  142. return []byte(cfg.AuthJwtHmacSecret), nil
  143. }
  144. opts := parserOptionsWithAudience(cfg)
  145. return jwt.Parse(jwtString, keyFunc, opts...)
  146. }
  147. func lookupClaimValueOrDefault(claims jwt.MapClaims, key string, def string) string {
  148. if val, ok := claims[key]; ok {
  149. return fmt.Sprintf("%s", val)
  150. } else {
  151. return def
  152. }
  153. }
  154. func CheckUserFromJwtCookie(context *authTypes.AuthCheckingContext) *authTypes.AuthenticatedUser {
  155. cookie, err := context.Request.Cookie(context.Config.AuthJwtCookieName)
  156. if err != nil {
  157. log.Debugf("jwt cookie check %v name: %v", err, context.Config.AuthJwtCookieName)
  158. return nil
  159. }
  160. return parseJwt(context.Config, cookie.Value)
  161. }
  162. func CheckUserFromJwtHeader(context *authTypes.AuthCheckingContext) *authTypes.AuthenticatedUser {
  163. header := context.Request.Header.Get(context.Config.AuthJwtHeader)
  164. if header == "" {
  165. return nil
  166. }
  167. token := strings.TrimPrefix(header, "Bearer ")
  168. token = strings.TrimSpace(token)
  169. return parseJwt(context.Config, token)
  170. }
  171. func parseJwt(cfg *config.Config, token string) *authTypes.AuthenticatedUser {
  172. claims, err := getClaimsFromJwtToken(cfg, token)
  173. if err != nil {
  174. log.Warnf("jwt claim error: %+v", err)
  175. return nil
  176. }
  177. if cfg.InsecureAllowDumpJwtClaims {
  178. log.Debugf("JWT Claims %+v", claims)
  179. }
  180. user := &authTypes.AuthenticatedUser{
  181. Username: lookupClaimValueOrDefault(claims, cfg.AuthJwtClaimUsername, ""),
  182. UsergroupLine: parseGroupClaim(cfg.AuthJwtClaimUserGroup, claims),
  183. Provider: "jwt",
  184. }
  185. return user
  186. }
  187. func parseGroupClaim(groupClaim string, claims jwt.MapClaims) string {
  188. usergroup := ""
  189. if val, ok := claims[groupClaim]; ok {
  190. if array, ok := val.([]interface{}); ok {
  191. groups := make([]string, len(array))
  192. for i, v := range array {
  193. groups[i] = fmt.Sprintf("%s", v)
  194. }
  195. usergroup = strings.Join(groups, " ")
  196. } else {
  197. usergroup = fmt.Sprintf("%s", val)
  198. }
  199. }
  200. return usergroup
  201. }