4
0

sessions.go 2.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125
  1. package auth
  2. import (
  3. "os"
  4. "sync"
  5. "time"
  6. "github.com/OliveTin/OliveTin/internal/config"
  7. "github.com/sirupsen/logrus"
  8. "gopkg.in/yaml.v3"
  9. )
  10. // Session management for user authentication
  11. type UserSession struct {
  12. Username string
  13. Expiry int64
  14. }
  15. type SessionProvider struct {
  16. Sessions map[string]*UserSession
  17. }
  18. type SessionStorage struct {
  19. Providers map[string]*SessionProvider
  20. }
  21. var (
  22. sessionStorage *SessionStorage
  23. sessionStorageMutex sync.RWMutex
  24. )
  25. func init() {
  26. sessionStorage = &SessionStorage{
  27. Providers: make(map[string]*SessionProvider),
  28. }
  29. }
  30. // RegisterUserSession registers a user session
  31. func RegisterUserSession(cfg *config.Config, provider string, sid string, username string) {
  32. sessionStorageMutex.Lock()
  33. defer sessionStorageMutex.Unlock()
  34. if sessionStorage.Providers[provider] == nil {
  35. sessionStorage.Providers[provider] = &SessionProvider{
  36. Sessions: make(map[string]*UserSession),
  37. }
  38. }
  39. if sessionStorage.Providers == nil {
  40. sessionStorage.Providers = make(map[string]*SessionProvider)
  41. }
  42. sessionStorage.Providers[provider].Sessions[sid] = &UserSession{
  43. Username: username,
  44. Expiry: time.Now().Unix() + 31556952, // 1 year
  45. }
  46. saveUserSessions(cfg)
  47. }
  48. // GetUserSession retrieves a user session
  49. func GetUserSession(provider string, sid string) *UserSession {
  50. sessionStorageMutex.Lock()
  51. defer sessionStorageMutex.Unlock()
  52. if sessionStorage.Providers[provider] == nil {
  53. return nil
  54. }
  55. session := sessionStorage.Providers[provider].Sessions[sid]
  56. if session == nil {
  57. return nil
  58. }
  59. if session.Expiry < time.Now().Unix() {
  60. delete(sessionStorage.Providers[provider].Sessions, sid)
  61. return nil
  62. }
  63. return session
  64. }
  65. // LoadUserSessions loads sessions from disk
  66. func LoadUserSessions(cfg *config.Config) {
  67. sessionStorageMutex.Lock()
  68. defer sessionStorageMutex.Unlock()
  69. data, err := os.ReadFile(cfg.GetDir() + "/sessions.yaml")
  70. if err != nil {
  71. logrus.WithError(err).Warn("Failed to read sessions.yaml file")
  72. ensureEmptySessionStorage()
  73. return
  74. }
  75. if err := yaml.Unmarshal(data, &sessionStorage); err != nil {
  76. logrus.WithError(err).Error("Failed to unmarshal sessions.yaml")
  77. ensureEmptySessionStorage()
  78. return
  79. }
  80. ensureEmptySessionStorage()
  81. }
  82. func ensureEmptySessionStorage() {
  83. if sessionStorage == nil {
  84. sessionStorage = &SessionStorage{Providers: make(map[string]*SessionProvider)}
  85. }
  86. if sessionStorage.Providers == nil {
  87. sessionStorage.Providers = make(map[string]*SessionProvider)
  88. }
  89. }
  90. func saveUserSessions(cfg *config.Config) {
  91. out, err := yaml.Marshal(sessionStorage)
  92. if err != nil {
  93. logrus.WithError(err).Error("Failed to marshal session storage")
  94. return
  95. }
  96. err = os.WriteFile(cfg.GetDir()+"/sessions.yaml", out, 0600)
  97. if err != nil {
  98. logrus.WithError(err).Error("Failed to write sessions.yaml file")
  99. return
  100. }
  101. }