sessions.go 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158
  1. package auth
import (
	"errors"
	"os"
	"path/filepath"
	"sync"
	"time"

	"github.com/OliveTin/OliveTin/internal/config"
	"github.com/sirupsen/logrus"
	"gopkg.in/yaml.v3"
)
  10. // Session management for user authentication
  11. type UserSession struct {
  12. Username string
  13. Expiry int64
  14. }
  15. type SessionProvider struct {
  16. Sessions map[string]*UserSession
  17. }
  18. type SessionStorage struct {
  19. Providers map[string]*SessionProvider
  20. }
  21. var (
  22. sessionStorage *SessionStorage
  23. sessionStorageMutex sync.RWMutex
  24. oauth2SessionRevoker func(sid string)
  25. )
  26. func init() {
  27. sessionStorage = &SessionStorage{
  28. Providers: make(map[string]*SessionProvider),
  29. }
  30. }
  31. // RegisterUserSession registers a user session
  32. func RegisterUserSession(cfg *config.Config, provider string, sid string, username string) {
  33. sessionStorageMutex.Lock()
  34. defer sessionStorageMutex.Unlock()
  35. if sessionStorage.Providers[provider] == nil {
  36. sessionStorage.Providers[provider] = &SessionProvider{
  37. Sessions: make(map[string]*UserSession),
  38. }
  39. }
  40. if sessionStorage.Providers == nil {
  41. sessionStorage.Providers = make(map[string]*SessionProvider)
  42. }
  43. sessionStorage.Providers[provider].Sessions[sid] = &UserSession{
  44. Username: username,
  45. Expiry: time.Now().Unix() + 31556952, // 1 year
  46. }
  47. saveUserSessions(cfg)
  48. }
  49. // RegisterOAuth2SessionRevoker registers a callback to revoke OAuth2 sessions on logout.
  50. // OAuth2 uses its own session storage; the API calls this when provider is oauth2.
  51. func RegisterOAuth2SessionRevoker(fn func(sid string)) {
  52. oauth2SessionRevoker = fn
  53. }
  54. // RevokeSessionForProvider invalidates the session for the given provider and SID (e.g. on logout).
  55. // Local auth uses shared SessionStorage; OAuth2 uses a separate storage and revoker.
  56. func RevokeSessionForProvider(cfg *config.Config, provider string, sid string) {
  57. if sid == "" {
  58. return
  59. }
  60. if provider == "oauth2" && oauth2SessionRevoker != nil {
  61. oauth2SessionRevoker(sid)
  62. return
  63. }
  64. RevokeUserSession(cfg, provider, sid)
  65. }
  66. // RevokeUserSession removes a session from storage so it can no longer be used (e.g. on logout).
  67. func RevokeUserSession(cfg *config.Config, provider string, sid string) {
  68. sessionStorageMutex.Lock()
  69. defer sessionStorageMutex.Unlock()
  70. if sessionStorage.Providers[provider] != nil {
  71. delete(sessionStorage.Providers[provider].Sessions, sid)
  72. if cfg != nil {
  73. saveUserSessions(cfg)
  74. }
  75. }
  76. }
  77. // GetUserSession retrieves a user session
  78. func GetUserSession(provider string, sid string) *UserSession {
  79. sessionStorageMutex.Lock()
  80. defer sessionStorageMutex.Unlock()
  81. if sessionStorage.Providers[provider] == nil {
  82. return nil
  83. }
  84. session := sessionStorage.Providers[provider].Sessions[sid]
  85. if session == nil {
  86. return nil
  87. }
  88. if session.Expiry < time.Now().Unix() {
  89. delete(sessionStorage.Providers[provider].Sessions, sid)
  90. return nil
  91. }
  92. return session
  93. }
  94. // LoadUserSessions loads sessions from disk
  95. func LoadUserSessions(cfg *config.Config) {
  96. sessionStorageMutex.Lock()
  97. defer sessionStorageMutex.Unlock()
  98. data, err := os.ReadFile(cfg.GetDir() + "/sessions.yaml")
  99. if err != nil {
  100. logrus.WithError(err).Warn("Failed to read sessions.yaml file")
  101. ensureEmptySessionStorage()
  102. return
  103. }
  104. if err := yaml.Unmarshal(data, &sessionStorage); err != nil {
  105. logrus.WithError(err).Error("Failed to unmarshal sessions.yaml")
  106. ensureEmptySessionStorage()
  107. return
  108. }
  109. ensureEmptySessionStorage()
  110. }
  111. func ensureEmptySessionStorage() {
  112. if sessionStorage == nil {
  113. sessionStorage = &SessionStorage{Providers: make(map[string]*SessionProvider)}
  114. }
  115. if sessionStorage.Providers == nil {
  116. sessionStorage.Providers = make(map[string]*SessionProvider)
  117. }
  118. }
  119. func saveUserSessions(cfg *config.Config) {
  120. out, err := yaml.Marshal(sessionStorage)
  121. if err != nil {
  122. logrus.WithError(err).Error("Failed to marshal session storage")
  123. return
  124. }
  125. err = os.WriteFile(cfg.GetDir()+"/sessions.yaml", out, 0600)
  126. if err != nil {
  127. logrus.WithError(err).Error("Failed to write sessions.yaml file")
  128. return
  129. }
  130. }