jwt_test.go 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225
  1. package otjwt
  2. import (
  3. "crypto/rand"
  4. "crypto/rsa"
  5. "crypto/x509"
  6. "encoding/pem"
  7. "io"
  8. "net/http"
  9. "net/http/httptest"
  10. "os"
  11. "testing"
  12. "time"
  13. "github.com/OliveTin/OliveTin/internal/auth/authpublic"
  14. config "github.com/OliveTin/OliveTin/internal/config"
  15. "github.com/golang-jwt/jwt/v5"
  16. "github.com/stretchr/testify/assert"
  17. )
  18. func generateRSAKeyPair(t *testing.T) (*rsa.PrivateKey, []byte) {
  19. privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
  20. if err != nil {
  21. t.Fatalf("failed to generate RSA key: %v", err)
  22. }
  23. pubKey := &privateKey.PublicKey
  24. pkixPubKey, err := x509.MarshalPKIXPublicKey(pubKey)
  25. if err != nil {
  26. t.Fatalf("failed to marshal public key: %v", err)
  27. }
  28. pubPem := pem.EncodeToMemory(
  29. &pem.Block{
  30. Type: "PUBLIC KEY",
  31. Bytes: pkixPubKey,
  32. },
  33. )
  34. return privateKey, pubPem
  35. }
  36. func createKeys(t *testing.T) (*rsa.PrivateKey, string) {
  37. tmpFile, err := os.CreateTemp(os.TempDir(), "olivetin-jwt-")
  38. if err != nil {
  39. t.Fatalf("failed to create temp file: %v", err)
  40. }
  41. defer tmpFile.Close()
  42. t.Logf("Created File: %s", tmpFile.Name())
  43. privateKey, pubPem := generateRSAKeyPair(t)
  44. if err := os.WriteFile(tmpFile.Name(), pubPem, 0644); err != nil {
  45. t.Fatalf("error when dumping pubKey: %s \n", err)
  46. }
  47. return privateKey, tmpFile.Name()
  48. }
  49. func newMux() *http.ServeMux {
  50. mux := http.NewServeMux()
  51. return mux
  52. }
  53. func createJWTTokenWithExpiration(t *testing.T, privateKey *rsa.PrivateKey, expire int64) string {
  54. return createJWTTokenWithExpirationAndAudience(t, privateKey, expire, "")
  55. }
  56. func createJWTTokenWithExpirationAndAudience(t *testing.T, privateKey *rsa.PrivateKey, expire int64, audience string) string {
  57. token := jwt.New(jwt.SigningMethodRS256)
  58. claims := token.Claims.(jwt.MapClaims)
  59. claims["nbf"] = time.Now().Unix() - 1000
  60. claims["exp"] = time.Now().Unix() + expire
  61. claims["sub"] = "test"
  62. claims["olivetinGroup"] = "test"
  63. if audience != "" {
  64. claims["aud"] = audience
  65. }
  66. tokenStr, err := token.SignedString(privateKey)
  67. if err != nil {
  68. t.Fatalf("failed to sign JWT token: %v", err)
  69. }
  70. return tokenStr
  71. }
  72. func setupJWTTestHandler(t *testing.T, cfg *config.Config) http.Handler {
  73. mux := newMux()
  74. mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
  75. context := &authpublic.AuthCheckingContext{
  76. Request: r,
  77. Config: cfg,
  78. }
  79. user := CheckUserFromJwtHeader(context)
  80. if user == nil {
  81. w.WriteHeader(403)
  82. return
  83. }
  84. assert.Equal(t, "test", user.Username)
  85. assert.Equal(t, "test", user.UsergroupLine)
  86. })
  87. return mux
  88. }
  89. func verifyJWTResponse(t *testing.T, res *http.Response, expectCode int) {
  90. defer res.Body.Close()
  91. assert.Equal(t, expectCode, res.StatusCode)
  92. body, _ := io.ReadAll(res.Body)
  93. t.Logf("Response body: %s", string(body))
  94. }
  95. func testJwkValidation(t *testing.T, expire int64, expectCode int) {
  96. testJwkValidationWithAudience(t, expire, expectCode, "", "")
  97. }
  98. func testJwkValidationWithAudience(t *testing.T, expire int64, expectCode int, configAudience, tokenAudience string) {
  99. privateKey, publicKeyPath := createKeys(t)
  100. defer os.Remove(publicKeyPath)
  101. cfg := config.DefaultConfig()
  102. cfg.AuthJwtPubKeyPath = publicKeyPath
  103. cfg.AuthJwtClaimUsername = "sub"
  104. cfg.AuthJwtClaimUserGroup = "olivetinGroup"
  105. cfg.AuthJwtHeader = "Authorization"
  106. cfg.AuthJwtAud = configAudience
  107. tokenStr := createJWTTokenWithExpirationAndAudience(t, privateKey, expire, tokenAudience)
  108. handler := setupJWTTestHandler(t, cfg)
  109. srv := httptest.NewServer(handler)
  110. defer srv.Close()
  111. res := makeJWTRequest(t, srv, tokenStr)
  112. verifyJWTResponse(t, res, expectCode)
  113. }
  114. func TestJWTSignatureVerificationSucceeds(t *testing.T) {
  115. testJwkValidation(t, 1000, 200)
  116. }
  117. func TestJWTSignatureVerificationFails(t *testing.T) {
  118. testJwkValidation(t, -500, 403)
  119. }
  120. func TestJWTAudienceValidationRejectsWrongAudience(t *testing.T) {
  121. testJwkValidationWithAudience(t, 1000, 403, "expected-audience", "wrong-audience")
  122. }
  123. func TestJWTAudienceValidationAcceptsCorrectAudience(t *testing.T) {
  124. testJwkValidationWithAudience(t, 1000, 200, "expected-audience", "expected-audience")
  125. }
  126. func createJWTTokenWithGroups(t *testing.T, privateKey *rsa.PrivateKey, groups interface{}) string {
  127. token := jwt.New(jwt.SigningMethodRS256)
  128. claims := token.Claims.(jwt.MapClaims)
  129. claims["nbf"] = time.Now().Unix() - 1000
  130. claims["exp"] = time.Now().Unix() + 2000
  131. claims["sub"] = "test"
  132. claims["olivetinGroup"] = groups
  133. tokenStr, err := token.SignedString(privateKey)
  134. if err != nil {
  135. t.Fatalf("failed to sign JWT token: %v", err)
  136. }
  137. return tokenStr
  138. }
  139. func makeJWTRequest(t *testing.T, srv *httptest.Server, tokenStr string) *http.Response {
  140. req, err := http.NewRequest("GET", srv.URL, nil)
  141. if err != nil {
  142. t.Fatalf("failed to create request: %v", err)
  143. }
  144. req.Header.Set("Authorization", "Bearer "+tokenStr)
  145. res, err := http.DefaultClient.Do(req)
  146. if err != nil {
  147. t.Fatalf("Client err: %+v", err)
  148. }
  149. return res
  150. }
  151. func TestJWTHeader(t *testing.T) {
  152. privateKey, publicKeyPath := createKeys(t)
  153. defer os.Remove(publicKeyPath)
  154. cfg := config.DefaultConfig()
  155. cfg.AuthJwtPubKeyPath = publicKeyPath
  156. cfg.AuthJwtClaimUsername = "sub"
  157. cfg.AuthJwtClaimUserGroup = "olivetinGroup"
  158. cfg.AuthJwtHeader = "Authorization"
  159. tokenStr := createJWTTokenWithGroups(t, privateKey, []string{"test", "test2"})
  160. mux := newMux()
  161. mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
  162. context := &authpublic.AuthCheckingContext{
  163. Request: r,
  164. Config: cfg,
  165. }
  166. user := CheckUserFromJwtHeader(context)
  167. if user == nil {
  168. w.WriteHeader(403)
  169. return
  170. }
  171. assert.Equal(t, "test", user.Username)
  172. assert.Equal(t, "test test2", user.UsergroupLine)
  173. })
  174. srv := httptest.NewServer(mux)
  175. defer srv.Close()
  176. res := makeJWTRequest(t, srv, tokenStr)
  177. defer res.Body.Close()
  178. assert.Equal(t, 200, res.StatusCode)
  179. body, _ := io.ReadAll(res.Body)
  180. t.Logf("Response body: %s", string(body))
  181. }