| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429 |
- // SPDX-FileCopyrightText: Copyright The Miniflux Authors. All rights reserved.
- // SPDX-License-Identifier: Apache-2.0
- package model
- import (
- "bytes"
- "database/sql"
- "encoding/json"
- "testing"
- "time"
- "github.com/go-webauthn/webauthn/webauthn"
- )
- func TestNewWebSession(t *testing.T) {
- const userAgent = "test-agent"
- const ip = "127.0.0.1"
- session, secret := NewWebSession(userAgent, ip)
- if session == nil {
- t.Fatal("NewWebSession returned a nil session")
- }
- if secret == "" {
- t.Error("NewWebSession returned an empty secret")
- }
- if session.ID == "" {
- t.Error("NewWebSession produced an empty ID")
- }
- if session.ID == secret {
- t.Error("session ID and secret must not be equal")
- }
- if len(session.SecretHash) == 0 {
- t.Error("NewWebSession produced an empty SecretHash")
- }
- if session.CSRF() == "" {
- t.Error("NewWebSession produced an empty CSRF token")
- }
- if session.UserAgent != userAgent {
- t.Errorf("UserAgent = %q, want %q", session.UserAgent, userAgent)
- }
- if session.IP != ip {
- t.Errorf("IP = %q, want %q", session.IP, ip)
- }
- if session.IsAuthenticated() {
- t.Error("a fresh session must not be authenticated")
- }
- if session.IsDirty() {
- t.Error("a fresh session must not be dirty")
- }
- if !session.VerifySecret(secret) {
- t.Error("VerifySecret rejected the secret returned by NewWebSession")
- }
- }
- func TestNewWebSession_ProducesUniqueIdentities(t *testing.T) {
- s1, secret1 := NewWebSession("", "")
- s2, secret2 := NewWebSession("", "")
- if s1.ID == s2.ID {
- t.Error("successive NewWebSession calls produced the same ID")
- }
- if secret1 == secret2 {
- t.Error("successive NewWebSession calls produced the same secret")
- }
- if bytes.Equal(s1.SecretHash, s2.SecretHash) {
- t.Error("successive NewWebSession calls produced the same SecretHash")
- }
- if s1.CSRF() == s2.CSRF() {
- t.Error("successive NewWebSession calls produced the same CSRF token")
- }
- }
- func TestWebSession_Rotate(t *testing.T) {
- session, originalSecret := NewWebSession("agent", "ip")
- originalID := session.ID
- originalHash := bytes.Clone(session.SecretHash)
- originalCSRF := session.CSRF()
- // Bind a user so we can verify Rotate preserves the user binding.
- session.SetUser(&User{ID: 42})
- oldID, newSecret := session.Rotate()
- if oldID != originalID {
- t.Errorf("Rotate returned oldID = %q, want %q", oldID, originalID)
- }
- if newSecret == "" {
- t.Error("Rotate returned an empty new secret")
- }
- if newSecret == originalSecret {
- t.Error("Rotate returned the same secret as before")
- }
- if session.ID == originalID {
- t.Error("Rotate did not change the session ID")
- }
- if bytes.Equal(session.SecretHash, originalHash) {
- t.Error("Rotate did not change the SecretHash")
- }
- if session.VerifySecret(originalSecret) {
- t.Error("VerifySecret must reject the pre-rotation secret")
- }
- if !session.VerifySecret(newSecret) {
- t.Error("VerifySecret must accept the post-rotation secret")
- }
- if session.CSRF() != originalCSRF {
- t.Error("Rotate must preserve the CSRF token so in-flight forms remain valid")
- }
- if !session.IsAuthenticated() {
- t.Error("Rotate must preserve the user binding")
- }
- if id, _ := session.UserID(); id != 42 {
- t.Errorf("Rotate corrupted user ID: got %d, want 42", id)
- }
- }
- func TestWebSession_VerifySecret(t *testing.T) {
- good, goodSecret := NewWebSession("", "")
- testCases := []struct {
- name string
- hash []byte
- secret string
- want bool
- }{
- {"correct secret", good.SecretHash, goodSecret, true},
- {"wrong secret", good.SecretHash, "not-the-right-secret", false},
- {"empty secret", good.SecretHash, "", false},
- {"nil hash", nil, goodSecret, false},
- {"empty hash and secret", nil, "", false},
- }
- for _, tc := range testCases {
- t.Run(tc.name, func(t *testing.T) {
- s := &WebSession{SecretHash: tc.hash}
- if got := s.VerifySecret(tc.secret); got != tc.want {
- t.Errorf("VerifySecret(%q) = %v, want %v", tc.secret, got, tc.want)
- }
- })
- }
- }
- func TestWebSession_UserBindingLifecycle(t *testing.T) {
- session, _ := NewWebSession("", "")
- if session.IsAuthenticated() {
- t.Error("a fresh session must not be authenticated")
- }
- if id, ok := session.UserID(); ok || id != 0 {
- t.Errorf("UserID() = (%d, %v), want (0, false)", id, ok)
- }
- user := &User{ID: 99, Language: "fr_FR", Theme: "dark_serif"}
- session.SetUser(user)
- if !session.IsAuthenticated() {
- t.Error("session must be authenticated after SetUser")
- }
- if id, ok := session.UserID(); !ok || id != 99 {
- t.Errorf("UserID() = (%d, %v), want (99, true)", id, ok)
- }
- if session.Language() != "fr_FR" {
- t.Errorf("SetUser did not copy Language: got %q, want %q", session.Language(), "fr_FR")
- }
- if session.Theme() != "dark_serif" {
- t.Errorf("SetUser did not copy Theme: got %q, want %q", session.Theme(), "dark_serif")
- }
- if !session.IsDirty() {
- t.Error("SetUser must mark the session dirty")
- }
- session.ClearUser()
- if session.IsAuthenticated() {
- t.Error("session must not be authenticated after ClearUser")
- }
- if id, ok := session.UserID(); ok || id != 0 {
- t.Errorf("UserID() after ClearUser = (%d, %v), want (0, false)", id, ok)
- }
- }
- func TestWebSession_SetUser_NilIsNoop(t *testing.T) {
- session, _ := NewWebSession("", "")
- session.SetUser(nil)
- if session.IsAuthenticated() {
- t.Error("SetUser(nil) must not authenticate the session")
- }
- if session.IsDirty() {
- t.Error("SetUser(nil) must not mark the session dirty")
- }
- }
- func TestWebSession_UserIDStorageRoundTrip(t *testing.T) {
- testCases := []struct {
- name string
- in sql.NullInt64
- }{
- {"null", sql.NullInt64{}},
- {"zero valid", sql.NullInt64{Int64: 0, Valid: true}},
- {"positive valid", sql.NullInt64{Int64: 42, Valid: true}},
- }
- for _, tc := range testCases {
- t.Run(tc.name, func(t *testing.T) {
- session := &WebSession{}
- session.ScanUserID(tc.in)
- if got := session.NullUserID(); got != tc.in {
- t.Errorf("round-trip = %+v, want %+v", got, tc.in)
- }
- if got := session.IsAuthenticated(); got != tc.in.Valid {
- t.Errorf("IsAuthenticated() = %v, want %v", got, tc.in.Valid)
- }
- })
- }
- }
- func TestWebSession_ScanUserID_ClearsPreviousValue(t *testing.T) {
- session := &WebSession{}
- session.ScanUserID(sql.NullInt64{Int64: 1, Valid: true})
- session.ScanUserID(sql.NullInt64{})
- if session.IsAuthenticated() {
- t.Error("ScanUserID with an invalid value must clear the user binding")
- }
- }
- func TestWebSession_LanguageAndThemeDefaults(t *testing.T) {
- session := &WebSession{}
- if got := session.Language(); got != defaultSessionLanguage {
- t.Errorf("default Language() = %q, want %q", got, defaultSessionLanguage)
- }
- if got := session.Theme(); got != defaultSessionTheme {
- t.Errorf("default Theme() = %q, want %q", got, defaultSessionTheme)
- }
- session.SetLanguage("de_DE")
- session.SetTheme("light_sans_serif")
- if got := session.Language(); got != "de_DE" {
- t.Errorf("Language() = %q, want %q", got, "de_DE")
- }
- if got := session.Theme(); got != "light_sans_serif" {
- t.Errorf("Theme() = %q, want %q", got, "light_sans_serif")
- }
- if !session.IsDirty() {
- t.Error("SetLanguage/SetTheme must mark the session dirty")
- }
- }
- func TestWebSession_OAuth2FlowLifecycle(t *testing.T) {
- session := &WebSession{}
- if session.OAuth2State() != "" {
- t.Error("OAuth2State() must be empty by default")
- }
- if session.OAuth2CodeVerifier() != "" {
- t.Error("OAuth2CodeVerifier() must be empty by default")
- }
- session.StartOAuth2Flow("state-token", "code-verifier")
- if got := session.OAuth2State(); got != "state-token" {
- t.Errorf("OAuth2State() = %q, want %q", got, "state-token")
- }
- if got := session.OAuth2CodeVerifier(); got != "code-verifier" {
- t.Errorf("OAuth2CodeVerifier() = %q, want %q", got, "code-verifier")
- }
- if !session.IsDirty() {
- t.Error("StartOAuth2Flow must mark the session dirty")
- }
- session.ClearOAuth2Flow()
- if session.OAuth2State() != "" {
- t.Errorf("OAuth2State() after Clear = %q, want empty", session.OAuth2State())
- }
- if session.OAuth2CodeVerifier() != "" {
- t.Errorf("OAuth2CodeVerifier() after Clear = %q, want empty", session.OAuth2CodeVerifier())
- }
- }
- func TestWebSession_ConsumeMessages(t *testing.T) {
- t.Run("no messages", func(t *testing.T) {
- session := &WebSession{}
- success, errMsg := session.ConsumeMessages()
- if success != "" || errMsg != "" {
- t.Errorf("ConsumeMessages() = (%q, %q), want empty", success, errMsg)
- }
- if session.IsDirty() {
- t.Error("ConsumeMessages with no messages must not mark the session dirty")
- }
- })
- t.Run("returns and clears", func(t *testing.T) {
- session := &WebSession{}
- session.SetSuccessMessage("saved")
- session.SetErrorMessage("nope")
- session.dirty = false // isolate the dirty contribution of ConsumeMessages
- success, errMsg := session.ConsumeMessages()
- if success != "saved" || errMsg != "nope" {
- t.Errorf("ConsumeMessages() = (%q, %q), want (%q, %q)", success, errMsg, "saved", "nope")
- }
- if !session.IsDirty() {
- t.Error("ConsumeMessages with messages must mark the session dirty")
- }
- success, errMsg = session.ConsumeMessages()
- if success != "" || errMsg != "" {
- t.Errorf("second ConsumeMessages() = (%q, %q), want empty", success, errMsg)
- }
- })
- }
- func TestWebSession_ConsumeWebAuthnSession(t *testing.T) {
- t.Run("no data", func(t *testing.T) {
- session := &WebSession{}
- if got := session.ConsumeWebAuthnSession(); got != nil {
- t.Errorf("ConsumeWebAuthnSession() = %v, want nil", got)
- }
- if session.IsDirty() {
- t.Error("ConsumeWebAuthnSession with no data must not mark the session dirty")
- }
- })
- t.Run("returns and clears", func(t *testing.T) {
- data := &webauthn.SessionData{}
- session := &WebSession{}
- session.SetWebAuthn(data)
- session.dirty = false // isolate the dirty contribution of ConsumeWebAuthnSession
- if got := session.ConsumeWebAuthnSession(); got != data {
- t.Errorf("ConsumeWebAuthnSession() = %p, want %p", got, data)
- }
- if !session.IsDirty() {
- t.Error("ConsumeWebAuthnSession with data must mark the session dirty")
- }
- if got := session.ConsumeWebAuthnSession(); got != nil {
- t.Errorf("second ConsumeWebAuthnSession() = %v, want nil", got)
- }
- })
- }
- func TestWebSession_MarkForceRefreshed(t *testing.T) {
- session := &WebSession{}
- if got := session.LastForceRefresh(); !got.IsZero() {
- t.Errorf("default LastForceRefresh() = %v, want zero time", got)
- }
- before := time.Now().UTC()
- session.MarkForceRefreshed()
- after := time.Now().UTC()
- got := session.LastForceRefresh()
- if got.Before(before) || got.After(after) {
- t.Errorf("LastForceRefresh() = %v, want between %v and %v", got, before, after)
- }
- if !session.IsDirty() {
- t.Error("MarkForceRefreshed must mark the session dirty")
- }
- }
- func TestWebSession_StateRoundTrip(t *testing.T) {
- original := &WebSession{}
- original.SetLanguage("de_DE")
- original.SetTheme("light_sans_serif")
- original.SetSuccessMessage("saved")
- original.SetErrorMessage("oops")
- original.StartOAuth2Flow("state-token", "code-verifier")
- original.MarkForceRefreshed()
- originalRefreshAt := original.LastForceRefresh()
- data, err := original.MarshalState()
- if err != nil {
- t.Fatalf("MarshalState() error: %v", err)
- }
- if !json.Valid(data) {
- t.Errorf("MarshalState() produced invalid JSON: %s", data)
- }
- restored := &WebSession{}
- if err := restored.UnmarshalState(data); err != nil {
- t.Fatalf("UnmarshalState() error: %v", err)
- }
- if got := restored.Language(); got != "de_DE" {
- t.Errorf("Language() = %q, want %q", got, "de_DE")
- }
- if got := restored.Theme(); got != "light_sans_serif" {
- t.Errorf("Theme() = %q, want %q", got, "light_sans_serif")
- }
- if got := restored.OAuth2State(); got != "state-token" {
- t.Errorf("OAuth2State() = %q, want %q", got, "state-token")
- }
- if got := restored.OAuth2CodeVerifier(); got != "code-verifier" {
- t.Errorf("OAuth2CodeVerifier() = %q, want %q", got, "code-verifier")
- }
- if got := restored.LastForceRefresh(); !got.Equal(originalRefreshAt) {
- t.Errorf("LastForceRefresh() = %v, want %v", got, originalRefreshAt)
- }
- success, errMsg := restored.ConsumeMessages()
- if success != "saved" || errMsg != "oops" {
- t.Errorf("ConsumeMessages() = (%q, %q), want (%q, %q)", success, errMsg, "saved", "oops")
- }
- }
- func TestWebSession_UnmarshalState_EmptyDataResetsState(t *testing.T) {
- session := &WebSession{}
- session.SetLanguage("fr_FR")
- session.StartOAuth2Flow("s", "v")
- if err := session.UnmarshalState(nil); err != nil {
- t.Fatalf("UnmarshalState(nil) error: %v", err)
- }
- if got := session.Language(); got != defaultSessionLanguage {
- t.Errorf("UnmarshalState(nil) did not reset Language: got %q", got)
- }
- if session.OAuth2State() != "" {
- t.Error("UnmarshalState(nil) did not reset OAuth2 state")
- }
- }
|