context_test.go 8.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327
  1. // SPDX-FileCopyrightText: Copyright The Miniflux Authors. All rights reserved.
  2. // SPDX-License-Identifier: Apache-2.0
  3. package request // import "miniflux.app/v2/internal/http/request"
  4. import (
  5. "context"
  6. "net/http"
  7. "testing"
  8. "miniflux.app/v2/internal/model"
  9. )
  10. func newRequestWithWebSession(session *model.WebSession) *http.Request {
  11. r, _ := http.NewRequest("GET", "http://example.org", nil)
  12. ctx := context.WithValue(r.Context(), WebSessionContextKey, session)
  13. return r.WithContext(ctx)
  14. }
  15. func TestContextStringValue(t *testing.T) {
  16. r, _ := http.NewRequest("GET", "http://example.org", nil)
  17. ctx := r.Context()
  18. ctx = context.WithValue(ctx, ClientIPContextKey, "IP")
  19. r = r.WithContext(ctx)
  20. result := getContextStringValue(r, ClientIPContextKey)
  21. expected := "IP"
  22. if result != expected {
  23. t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
  24. }
  25. }
  26. func TestContextStringValueWithInvalidType(t *testing.T) {
  27. r, _ := http.NewRequest("GET", "http://example.org", nil)
  28. ctx := r.Context()
  29. ctx = context.WithValue(ctx, ClientIPContextKey, 0)
  30. r = r.WithContext(ctx)
  31. result := getContextStringValue(r, ClientIPContextKey)
  32. expected := ""
  33. if result != expected {
  34. t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
  35. }
  36. }
  37. func TestContextStringValueWhenUnset(t *testing.T) {
  38. r, _ := http.NewRequest("GET", "http://example.org", nil)
  39. result := getContextStringValue(r, ClientIPContextKey)
  40. expected := ""
  41. if result != expected {
  42. t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
  43. }
  44. }
  45. func TestContextBoolValue(t *testing.T) {
  46. r, _ := http.NewRequest("GET", "http://example.org", nil)
  47. ctx := r.Context()
  48. ctx = context.WithValue(ctx, IsAdminUserContextKey, true)
  49. r = r.WithContext(ctx)
  50. result := getContextBoolValue(r, IsAdminUserContextKey)
  51. expected := true
  52. if result != expected {
  53. t.Errorf(`Unexpected context value, got %v instead of %v`, result, expected)
  54. }
  55. }
  56. func TestContextBoolValueWithInvalidType(t *testing.T) {
  57. r, _ := http.NewRequest("GET", "http://example.org", nil)
  58. ctx := r.Context()
  59. ctx = context.WithValue(ctx, IsAdminUserContextKey, "invalid")
  60. r = r.WithContext(ctx)
  61. result := getContextBoolValue(r, IsAdminUserContextKey)
  62. expected := false
  63. if result != expected {
  64. t.Errorf(`Unexpected context value, got %v instead of %v`, result, expected)
  65. }
  66. }
  67. func TestContextBoolValueWhenUnset(t *testing.T) {
  68. r, _ := http.NewRequest("GET", "http://example.org", nil)
  69. result := getContextBoolValue(r, IsAdminUserContextKey)
  70. expected := false
  71. if result != expected {
  72. t.Errorf(`Unexpected context value, got %v instead of %v`, result, expected)
  73. }
  74. }
  75. func TestContextInt64Value(t *testing.T) {
  76. r, _ := http.NewRequest("GET", "http://example.org", nil)
  77. ctx := r.Context()
  78. ctx = context.WithValue(ctx, UserIDContextKey, int64(1234))
  79. r = r.WithContext(ctx)
  80. result := getContextInt64Value(r, UserIDContextKey)
  81. expected := int64(1234)
  82. if result != expected {
  83. t.Errorf(`Unexpected context value, got %d instead of %d`, result, expected)
  84. }
  85. }
  86. func TestContextInt64ValueWithInvalidType(t *testing.T) {
  87. r, _ := http.NewRequest("GET", "http://example.org", nil)
  88. ctx := r.Context()
  89. ctx = context.WithValue(ctx, UserIDContextKey, "invalid")
  90. r = r.WithContext(ctx)
  91. result := getContextInt64Value(r, UserIDContextKey)
  92. expected := int64(0)
  93. if result != expected {
  94. t.Errorf(`Unexpected context value, got %d instead of %d`, result, expected)
  95. }
  96. }
  97. func TestContextInt64ValueWhenUnset(t *testing.T) {
  98. r, _ := http.NewRequest("GET", "http://example.org", nil)
  99. result := getContextInt64Value(r, UserIDContextKey)
  100. expected := int64(0)
  101. if result != expected {
  102. t.Errorf(`Unexpected context value, got %d instead of %d`, result, expected)
  103. }
  104. }
  105. func TestIsAdmin(t *testing.T) {
  106. r, _ := http.NewRequest("GET", "http://example.org", nil)
  107. result := IsAdminUser(r)
  108. expected := false
  109. if result != expected {
  110. t.Errorf(`Unexpected context value, got %v instead of %v`, result, expected)
  111. }
  112. ctx := r.Context()
  113. ctx = context.WithValue(ctx, IsAdminUserContextKey, true)
  114. r = r.WithContext(ctx)
  115. result = IsAdminUser(r)
  116. expected = true
  117. if result != expected {
  118. t.Errorf(`Unexpected context value, got %v instead of %v`, result, expected)
  119. }
  120. }
  121. func TestIsAuthenticated(t *testing.T) {
  122. r, _ := http.NewRequest("GET", "http://example.org", nil)
  123. result := IsAuthenticated(r)
  124. expected := false
  125. if result != expected {
  126. t.Errorf(`Unexpected context value, got %v instead of %v`, result, expected)
  127. }
  128. ctx := r.Context()
  129. ctx = context.WithValue(ctx, IsAuthenticatedContextKey, true)
  130. r = r.WithContext(ctx)
  131. result = IsAuthenticated(r)
  132. expected = true
  133. if result != expected {
  134. t.Errorf(`Unexpected context value, got %v instead of %v`, result, expected)
  135. }
  136. session := &model.WebSession{}
  137. session.SetUser(&model.User{ID: 42})
  138. r = newRequestWithWebSession(session)
  139. result = IsAuthenticated(r)
  140. if !result {
  141. t.Errorf("Unexpected context value, got %v instead of true", result)
  142. }
  143. }
  144. func TestUserID(t *testing.T) {
  145. r, _ := http.NewRequest("GET", "http://example.org", nil)
  146. result := UserID(r)
  147. expected := int64(0)
  148. if result != expected {
  149. t.Errorf(`Unexpected context value, got %v instead of %v`, result, expected)
  150. }
  151. ctx := r.Context()
  152. ctx = context.WithValue(ctx, UserIDContextKey, int64(123))
  153. r = r.WithContext(ctx)
  154. result = UserID(r)
  155. expected = int64(123)
  156. if result != expected {
  157. t.Errorf(`Unexpected context value, got %v instead of %v`, result, expected)
  158. }
  159. session := &model.WebSession{}
  160. session.SetUser(&model.User{ID: 456})
  161. r = newRequestWithWebSession(session)
  162. result = UserID(r)
  163. expected = int64(456)
  164. if result != expected {
  165. t.Errorf(`Unexpected context value, got %v instead of %v`, result, expected)
  166. }
  167. }
  168. func TestUserName(t *testing.T) {
  169. r, _ := http.NewRequest("GET", "http://example.org", nil)
  170. result := UserName(r)
  171. expected := "unknown"
  172. if result != expected {
  173. t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
  174. }
  175. ctx := r.Context()
  176. ctx = context.WithValue(ctx, UserNameContextKey, "jane")
  177. r = r.WithContext(ctx)
  178. result = UserName(r)
  179. expected = "jane"
  180. if result != expected {
  181. t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
  182. }
  183. }
  184. func TestUserTimezone(t *testing.T) {
  185. r, _ := http.NewRequest("GET", "http://example.org", nil)
  186. result := UserTimezone(r)
  187. expected := "UTC"
  188. if result != expected {
  189. t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
  190. }
  191. ctx := r.Context()
  192. ctx = context.WithValue(ctx, UserTimezoneContextKey, "Europe/Paris")
  193. r = r.WithContext(ctx)
  194. result = UserTimezone(r)
  195. expected = "Europe/Paris"
  196. if result != expected {
  197. t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
  198. }
  199. }
  200. func TestWebSession(t *testing.T) {
  201. r, _ := http.NewRequest("GET", "http://example.org", nil)
  202. if result := WebSession(r); result != nil {
  203. t.Fatalf("Unexpected context value, got %v instead of nil", result)
  204. }
  205. session := &model.WebSession{ID: "session-id"}
  206. ctx := r.Context()
  207. ctx = context.WithValue(ctx, WebSessionContextKey, session)
  208. r = r.WithContext(ctx)
  209. result := WebSession(r)
  210. if result == nil || result.ID != "session-id" {
  211. t.Fatalf("Unexpected context value, got %#v instead of session-id", result)
  212. }
  213. }
  214. func TestClientIP(t *testing.T) {
  215. r, _ := http.NewRequest("GET", "http://example.org", nil)
  216. result := ClientIP(r)
  217. expected := ""
  218. if result != expected {
  219. t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
  220. }
  221. ctx := r.Context()
  222. ctx = context.WithValue(ctx, ClientIPContextKey, "127.0.0.1")
  223. r = r.WithContext(ctx)
  224. result = ClientIP(r)
  225. expected = "127.0.0.1"
  226. if result != expected {
  227. t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
  228. }
  229. }
  230. func TestGoogleReaderToken(t *testing.T) {
  231. r, _ := http.NewRequest("GET", "http://example.org", nil)
  232. result := GoogleReaderToken(r)
  233. expected := ""
  234. if result != expected {
  235. t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
  236. }
  237. ctx := r.Context()
  238. ctx = context.WithValue(ctx, GoogleReaderTokenKey, "token")
  239. r = r.WithContext(ctx)
  240. result = GoogleReaderToken(r)
  241. expected = "token"
  242. if result != expected {
  243. t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
  244. }
  245. }