|
|
@@ -0,0 +1,429 @@
|
|
|
+// SPDX-FileCopyrightText: Copyright The Miniflux Authors. All rights reserved.
|
|
|
+// SPDX-License-Identifier: Apache-2.0
|
|
|
+
|
|
|
+package model
|
|
|
+
|
|
|
+import (
|
|
|
+ "bytes"
|
|
|
+ "database/sql"
|
|
|
+ "encoding/json"
|
|
|
+ "testing"
|
|
|
+ "time"
|
|
|
+
|
|
|
+ "github.com/go-webauthn/webauthn/webauthn"
|
|
|
+)
|
|
|
+
|
|
|
+func TestNewWebSession(t *testing.T) {
|
|
|
+ const userAgent = "test-agent"
|
|
|
+ const ip = "127.0.0.1"
|
|
|
+
|
|
|
+ session, secret := NewWebSession(userAgent, ip)
|
|
|
+
|
|
|
+ if session == nil {
|
|
|
+ t.Fatal("NewWebSession returned a nil session")
|
|
|
+ }
|
|
|
+ if secret == "" {
|
|
|
+ t.Error("NewWebSession returned an empty secret")
|
|
|
+ }
|
|
|
+ if session.ID == "" {
|
|
|
+ t.Error("NewWebSession produced an empty ID")
|
|
|
+ }
|
|
|
+ if session.ID == secret {
|
|
|
+ t.Error("session ID and secret must not be equal")
|
|
|
+ }
|
|
|
+ if len(session.SecretHash) == 0 {
|
|
|
+ t.Error("NewWebSession produced an empty SecretHash")
|
|
|
+ }
|
|
|
+ if session.CSRF() == "" {
|
|
|
+ t.Error("NewWebSession produced an empty CSRF token")
|
|
|
+ }
|
|
|
+ if session.UserAgent != userAgent {
|
|
|
+ t.Errorf("UserAgent = %q, want %q", session.UserAgent, userAgent)
|
|
|
+ }
|
|
|
+ if session.IP != ip {
|
|
|
+ t.Errorf("IP = %q, want %q", session.IP, ip)
|
|
|
+ }
|
|
|
+ if session.IsAuthenticated() {
|
|
|
+ t.Error("a fresh session must not be authenticated")
|
|
|
+ }
|
|
|
+ if session.IsDirty() {
|
|
|
+ t.Error("a fresh session must not be dirty")
|
|
|
+ }
|
|
|
+ if !session.VerifySecret(secret) {
|
|
|
+ t.Error("VerifySecret rejected the secret returned by NewWebSession")
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func TestNewWebSession_ProducesUniqueIdentities(t *testing.T) {
|
|
|
+ s1, secret1 := NewWebSession("", "")
|
|
|
+ s2, secret2 := NewWebSession("", "")
|
|
|
+
|
|
|
+ if s1.ID == s2.ID {
|
|
|
+ t.Error("successive NewWebSession calls produced the same ID")
|
|
|
+ }
|
|
|
+ if secret1 == secret2 {
|
|
|
+ t.Error("successive NewWebSession calls produced the same secret")
|
|
|
+ }
|
|
|
+ if bytes.Equal(s1.SecretHash, s2.SecretHash) {
|
|
|
+ t.Error("successive NewWebSession calls produced the same SecretHash")
|
|
|
+ }
|
|
|
+ if s1.CSRF() == s2.CSRF() {
|
|
|
+ t.Error("successive NewWebSession calls produced the same CSRF token")
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func TestWebSession_Rotate(t *testing.T) {
|
|
|
+ session, originalSecret := NewWebSession("agent", "ip")
|
|
|
+ originalID := session.ID
|
|
|
+ originalHash := bytes.Clone(session.SecretHash)
|
|
|
+ originalCSRF := session.CSRF()
|
|
|
+
|
|
|
+ // Bind a user so we can verify Rotate preserves the user binding.
|
|
|
+ session.SetUser(&User{ID: 42})
|
|
|
+
|
|
|
+ oldID, newSecret := session.Rotate()
|
|
|
+
|
|
|
+ if oldID != originalID {
|
|
|
+ t.Errorf("Rotate returned oldID = %q, want %q", oldID, originalID)
|
|
|
+ }
|
|
|
+ if newSecret == "" {
|
|
|
+ t.Error("Rotate returned an empty new secret")
|
|
|
+ }
|
|
|
+ if newSecret == originalSecret {
|
|
|
+ t.Error("Rotate returned the same secret as before")
|
|
|
+ }
|
|
|
+ if session.ID == originalID {
|
|
|
+ t.Error("Rotate did not change the session ID")
|
|
|
+ }
|
|
|
+ if bytes.Equal(session.SecretHash, originalHash) {
|
|
|
+ t.Error("Rotate did not change the SecretHash")
|
|
|
+ }
|
|
|
+ if session.VerifySecret(originalSecret) {
|
|
|
+ t.Error("VerifySecret must reject the pre-rotation secret")
|
|
|
+ }
|
|
|
+ if !session.VerifySecret(newSecret) {
|
|
|
+ t.Error("VerifySecret must accept the post-rotation secret")
|
|
|
+ }
|
|
|
+ if session.CSRF() != originalCSRF {
|
|
|
+ t.Error("Rotate must preserve the CSRF token so in-flight forms remain valid")
|
|
|
+ }
|
|
|
+ if !session.IsAuthenticated() {
|
|
|
+ t.Error("Rotate must preserve the user binding")
|
|
|
+ }
|
|
|
+ if id, _ := session.UserID(); id != 42 {
|
|
|
+ t.Errorf("Rotate corrupted user ID: got %d, want 42", id)
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func TestWebSession_VerifySecret(t *testing.T) {
|
|
|
+ good, goodSecret := NewWebSession("", "")
|
|
|
+
|
|
|
+ testCases := []struct {
|
|
|
+ name string
|
|
|
+ hash []byte
|
|
|
+ secret string
|
|
|
+ want bool
|
|
|
+ }{
|
|
|
+ {"correct secret", good.SecretHash, goodSecret, true},
|
|
|
+ {"wrong secret", good.SecretHash, "not-the-right-secret", false},
|
|
|
+ {"empty secret", good.SecretHash, "", false},
|
|
|
+ {"nil hash", nil, goodSecret, false},
|
|
|
+ {"empty hash and secret", nil, "", false},
|
|
|
+ }
|
|
|
+
|
|
|
+ for _, tc := range testCases {
|
|
|
+ t.Run(tc.name, func(t *testing.T) {
|
|
|
+ s := &WebSession{SecretHash: tc.hash}
|
|
|
+ if got := s.VerifySecret(tc.secret); got != tc.want {
|
|
|
+ t.Errorf("VerifySecret(%q) = %v, want %v", tc.secret, got, tc.want)
|
|
|
+ }
|
|
|
+ })
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func TestWebSession_UserBindingLifecycle(t *testing.T) {
|
|
|
+ session, _ := NewWebSession("", "")
|
|
|
+
|
|
|
+ if session.IsAuthenticated() {
|
|
|
+ t.Error("a fresh session must not be authenticated")
|
|
|
+ }
|
|
|
+ if id, ok := session.UserID(); ok || id != 0 {
|
|
|
+ t.Errorf("UserID() = (%d, %v), want (0, false)", id, ok)
|
|
|
+ }
|
|
|
+
|
|
|
+ user := &User{ID: 99, Language: "fr_FR", Theme: "dark_serif"}
|
|
|
+ session.SetUser(user)
|
|
|
+
|
|
|
+ if !session.IsAuthenticated() {
|
|
|
+ t.Error("session must be authenticated after SetUser")
|
|
|
+ }
|
|
|
+ if id, ok := session.UserID(); !ok || id != 99 {
|
|
|
+ t.Errorf("UserID() = (%d, %v), want (99, true)", id, ok)
|
|
|
+ }
|
|
|
+ if session.Language() != "fr_FR" {
|
|
|
+ t.Errorf("SetUser did not copy Language: got %q, want %q", session.Language(), "fr_FR")
|
|
|
+ }
|
|
|
+ if session.Theme() != "dark_serif" {
|
|
|
+ t.Errorf("SetUser did not copy Theme: got %q, want %q", session.Theme(), "dark_serif")
|
|
|
+ }
|
|
|
+ if !session.IsDirty() {
|
|
|
+ t.Error("SetUser must mark the session dirty")
|
|
|
+ }
|
|
|
+
|
|
|
+ session.ClearUser()
|
|
|
+ if session.IsAuthenticated() {
|
|
|
+ t.Error("session must not be authenticated after ClearUser")
|
|
|
+ }
|
|
|
+ if id, ok := session.UserID(); ok || id != 0 {
|
|
|
+ t.Errorf("UserID() after ClearUser = (%d, %v), want (0, false)", id, ok)
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func TestWebSession_SetUser_NilIsNoop(t *testing.T) {
|
|
|
+ session, _ := NewWebSession("", "")
|
|
|
+ session.SetUser(nil)
|
|
|
+
|
|
|
+ if session.IsAuthenticated() {
|
|
|
+ t.Error("SetUser(nil) must not authenticate the session")
|
|
|
+ }
|
|
|
+ if session.IsDirty() {
|
|
|
+ t.Error("SetUser(nil) must not mark the session dirty")
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func TestWebSession_UserIDStorageRoundTrip(t *testing.T) {
|
|
|
+ testCases := []struct {
|
|
|
+ name string
|
|
|
+ in sql.NullInt64
|
|
|
+ }{
|
|
|
+ {"null", sql.NullInt64{}},
|
|
|
+ {"zero valid", sql.NullInt64{Int64: 0, Valid: true}},
|
|
|
+ {"positive valid", sql.NullInt64{Int64: 42, Valid: true}},
|
|
|
+ }
|
|
|
+
|
|
|
+ for _, tc := range testCases {
|
|
|
+ t.Run(tc.name, func(t *testing.T) {
|
|
|
+ session := &WebSession{}
|
|
|
+ session.ScanUserID(tc.in)
|
|
|
+
|
|
|
+ if got := session.NullUserID(); got != tc.in {
|
|
|
+ t.Errorf("round-trip = %+v, want %+v", got, tc.in)
|
|
|
+ }
|
|
|
+ if got := session.IsAuthenticated(); got != tc.in.Valid {
|
|
|
+ t.Errorf("IsAuthenticated() = %v, want %v", got, tc.in.Valid)
|
|
|
+ }
|
|
|
+ })
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func TestWebSession_ScanUserID_ClearsPreviousValue(t *testing.T) {
|
|
|
+ session := &WebSession{}
|
|
|
+ session.ScanUserID(sql.NullInt64{Int64: 1, Valid: true})
|
|
|
+ session.ScanUserID(sql.NullInt64{})
|
|
|
+
|
|
|
+ if session.IsAuthenticated() {
|
|
|
+ t.Error("ScanUserID with an invalid value must clear the user binding")
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func TestWebSession_LanguageAndThemeDefaults(t *testing.T) {
|
|
|
+ session := &WebSession{}
|
|
|
+
|
|
|
+ if got := session.Language(); got != defaultSessionLanguage {
|
|
|
+ t.Errorf("default Language() = %q, want %q", got, defaultSessionLanguage)
|
|
|
+ }
|
|
|
+ if got := session.Theme(); got != defaultSessionTheme {
|
|
|
+ t.Errorf("default Theme() = %q, want %q", got, defaultSessionTheme)
|
|
|
+ }
|
|
|
+
|
|
|
+ session.SetLanguage("de_DE")
|
|
|
+ session.SetTheme("light_sans_serif")
|
|
|
+
|
|
|
+ if got := session.Language(); got != "de_DE" {
|
|
|
+ t.Errorf("Language() = %q, want %q", got, "de_DE")
|
|
|
+ }
|
|
|
+ if got := session.Theme(); got != "light_sans_serif" {
|
|
|
+ t.Errorf("Theme() = %q, want %q", got, "light_sans_serif")
|
|
|
+ }
|
|
|
+ if !session.IsDirty() {
|
|
|
+ t.Error("SetLanguage/SetTheme must mark the session dirty")
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func TestWebSession_OAuth2FlowLifecycle(t *testing.T) {
|
|
|
+ session := &WebSession{}
|
|
|
+
|
|
|
+ if session.OAuth2State() != "" {
|
|
|
+ t.Error("OAuth2State() must be empty by default")
|
|
|
+ }
|
|
|
+ if session.OAuth2CodeVerifier() != "" {
|
|
|
+ t.Error("OAuth2CodeVerifier() must be empty by default")
|
|
|
+ }
|
|
|
+
|
|
|
+ session.StartOAuth2Flow("state-token", "code-verifier")
|
|
|
+
|
|
|
+ if got := session.OAuth2State(); got != "state-token" {
|
|
|
+ t.Errorf("OAuth2State() = %q, want %q", got, "state-token")
|
|
|
+ }
|
|
|
+ if got := session.OAuth2CodeVerifier(); got != "code-verifier" {
|
|
|
+ t.Errorf("OAuth2CodeVerifier() = %q, want %q", got, "code-verifier")
|
|
|
+ }
|
|
|
+ if !session.IsDirty() {
|
|
|
+ t.Error("StartOAuth2Flow must mark the session dirty")
|
|
|
+ }
|
|
|
+
|
|
|
+ session.ClearOAuth2Flow()
|
|
|
+
|
|
|
+ if session.OAuth2State() != "" {
|
|
|
+ t.Errorf("OAuth2State() after Clear = %q, want empty", session.OAuth2State())
|
|
|
+ }
|
|
|
+ if session.OAuth2CodeVerifier() != "" {
|
|
|
+ t.Errorf("OAuth2CodeVerifier() after Clear = %q, want empty", session.OAuth2CodeVerifier())
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func TestWebSession_ConsumeMessages(t *testing.T) {
|
|
|
+ t.Run("no messages", func(t *testing.T) {
|
|
|
+ session := &WebSession{}
|
|
|
+
|
|
|
+ success, errMsg := session.ConsumeMessages()
|
|
|
+ if success != "" || errMsg != "" {
|
|
|
+ t.Errorf("ConsumeMessages() = (%q, %q), want empty", success, errMsg)
|
|
|
+ }
|
|
|
+ if session.IsDirty() {
|
|
|
+ t.Error("ConsumeMessages with no messages must not mark the session dirty")
|
|
|
+ }
|
|
|
+ })
|
|
|
+
|
|
|
+ t.Run("returns and clears", func(t *testing.T) {
|
|
|
+ session := &WebSession{}
|
|
|
+ session.SetSuccessMessage("saved")
|
|
|
+ session.SetErrorMessage("nope")
|
|
|
+ session.dirty = false // isolate the dirty contribution of ConsumeMessages
|
|
|
+
|
|
|
+ success, errMsg := session.ConsumeMessages()
|
|
|
+ if success != "saved" || errMsg != "nope" {
|
|
|
+ t.Errorf("ConsumeMessages() = (%q, %q), want (%q, %q)", success, errMsg, "saved", "nope")
|
|
|
+ }
|
|
|
+ if !session.IsDirty() {
|
|
|
+ t.Error("ConsumeMessages with messages must mark the session dirty")
|
|
|
+ }
|
|
|
+
|
|
|
+ success, errMsg = session.ConsumeMessages()
|
|
|
+ if success != "" || errMsg != "" {
|
|
|
+ t.Errorf("second ConsumeMessages() = (%q, %q), want empty", success, errMsg)
|
|
|
+ }
|
|
|
+ })
|
|
|
+}
|
|
|
+
|
|
|
+func TestWebSession_ConsumeWebAuthnSession(t *testing.T) {
|
|
|
+ t.Run("no data", func(t *testing.T) {
|
|
|
+ session := &WebSession{}
|
|
|
+
|
|
|
+ if got := session.ConsumeWebAuthnSession(); got != nil {
|
|
|
+ t.Errorf("ConsumeWebAuthnSession() = %v, want nil", got)
|
|
|
+ }
|
|
|
+ if session.IsDirty() {
|
|
|
+ t.Error("ConsumeWebAuthnSession with no data must not mark the session dirty")
|
|
|
+ }
|
|
|
+ })
|
|
|
+
|
|
|
+ t.Run("returns and clears", func(t *testing.T) {
|
|
|
+ data := &webauthn.SessionData{}
|
|
|
+ session := &WebSession{}
|
|
|
+ session.SetWebAuthn(data)
|
|
|
+ session.dirty = false // isolate the dirty contribution of ConsumeWebAuthnSession
|
|
|
+
|
|
|
+ if got := session.ConsumeWebAuthnSession(); got != data {
|
|
|
+ t.Errorf("ConsumeWebAuthnSession() = %p, want %p", got, data)
|
|
|
+ }
|
|
|
+ if !session.IsDirty() {
|
|
|
+ t.Error("ConsumeWebAuthnSession with data must mark the session dirty")
|
|
|
+ }
|
|
|
+ if got := session.ConsumeWebAuthnSession(); got != nil {
|
|
|
+ t.Errorf("second ConsumeWebAuthnSession() = %v, want nil", got)
|
|
|
+ }
|
|
|
+ })
|
|
|
+}
|
|
|
+
|
|
|
+func TestWebSession_MarkForceRefreshed(t *testing.T) {
|
|
|
+ session := &WebSession{}
|
|
|
+
|
|
|
+ if got := session.LastForceRefresh(); !got.IsZero() {
|
|
|
+ t.Errorf("default LastForceRefresh() = %v, want zero time", got)
|
|
|
+ }
|
|
|
+
|
|
|
+ before := time.Now().UTC()
|
|
|
+ session.MarkForceRefreshed()
|
|
|
+ after := time.Now().UTC()
|
|
|
+
|
|
|
+ got := session.LastForceRefresh()
|
|
|
+ if got.Before(before) || got.After(after) {
|
|
|
+ t.Errorf("LastForceRefresh() = %v, want between %v and %v", got, before, after)
|
|
|
+ }
|
|
|
+ if !session.IsDirty() {
|
|
|
+ t.Error("MarkForceRefreshed must mark the session dirty")
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func TestWebSession_StateRoundTrip(t *testing.T) {
|
|
|
+ original := &WebSession{}
|
|
|
+ original.SetLanguage("de_DE")
|
|
|
+ original.SetTheme("light_sans_serif")
|
|
|
+ original.SetSuccessMessage("saved")
|
|
|
+ original.SetErrorMessage("oops")
|
|
|
+ original.StartOAuth2Flow("state-token", "code-verifier")
|
|
|
+ original.MarkForceRefreshed()
|
|
|
+ originalRefreshAt := original.LastForceRefresh()
|
|
|
+
|
|
|
+ data, err := original.MarshalState()
|
|
|
+ if err != nil {
|
|
|
+ t.Fatalf("MarshalState() error: %v", err)
|
|
|
+ }
|
|
|
+ if !json.Valid(data) {
|
|
|
+ t.Errorf("MarshalState() produced invalid JSON: %s", data)
|
|
|
+ }
|
|
|
+
|
|
|
+ restored := &WebSession{}
|
|
|
+ if err := restored.UnmarshalState(data); err != nil {
|
|
|
+ t.Fatalf("UnmarshalState() error: %v", err)
|
|
|
+ }
|
|
|
+
|
|
|
+ if got := restored.Language(); got != "de_DE" {
|
|
|
+ t.Errorf("Language() = %q, want %q", got, "de_DE")
|
|
|
+ }
|
|
|
+ if got := restored.Theme(); got != "light_sans_serif" {
|
|
|
+ t.Errorf("Theme() = %q, want %q", got, "light_sans_serif")
|
|
|
+ }
|
|
|
+ if got := restored.OAuth2State(); got != "state-token" {
|
|
|
+ t.Errorf("OAuth2State() = %q, want %q", got, "state-token")
|
|
|
+ }
|
|
|
+ if got := restored.OAuth2CodeVerifier(); got != "code-verifier" {
|
|
|
+ t.Errorf("OAuth2CodeVerifier() = %q, want %q", got, "code-verifier")
|
|
|
+ }
|
|
|
+ if got := restored.LastForceRefresh(); !got.Equal(originalRefreshAt) {
|
|
|
+ t.Errorf("LastForceRefresh() = %v, want %v", got, originalRefreshAt)
|
|
|
+ }
|
|
|
+
|
|
|
+ success, errMsg := restored.ConsumeMessages()
|
|
|
+ if success != "saved" || errMsg != "oops" {
|
|
|
+ t.Errorf("ConsumeMessages() = (%q, %q), want (%q, %q)", success, errMsg, "saved", "oops")
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func TestWebSession_UnmarshalState_EmptyDataResetsState(t *testing.T) {
|
|
|
+ session := &WebSession{}
|
|
|
+ session.SetLanguage("fr_FR")
|
|
|
+ session.StartOAuth2Flow("s", "v")
|
|
|
+
|
|
|
+ if err := session.UnmarshalState(nil); err != nil {
|
|
|
+ t.Fatalf("UnmarshalState(nil) error: %v", err)
|
|
|
+ }
|
|
|
+
|
|
|
+ if got := session.Language(); got != defaultSessionLanguage {
|
|
|
+ t.Errorf("UnmarshalState(nil) did not reset Language: got %q", got)
|
|
|
+ }
|
|
|
+ if session.OAuth2State() != "" {
|
|
|
+ t.Error("UnmarshalState(nil) did not reset OAuth2 state")
|
|
|
+ }
|
|
|
+}
|