context_test.go 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543
  1. // SPDX-FileCopyrightText: Copyright The Miniflux Authors. All rights reserved.
  2. // SPDX-License-Identifier: Apache-2.0
  3. package request // import "miniflux.app/v2/internal/http/request"
  4. import (
  5. "context"
  6. "net/http"
  7. "testing"
  8. "time"
  9. "miniflux.app/v2/internal/model"
  10. )
  11. func TestContextStringValue(t *testing.T) {
  12. r, _ := http.NewRequest("GET", "http://example.org", nil)
  13. ctx := r.Context()
  14. ctx = context.WithValue(ctx, ClientIPContextKey, "IP")
  15. r = r.WithContext(ctx)
  16. result := getContextStringValue(r, ClientIPContextKey)
  17. expected := "IP"
  18. if result != expected {
  19. t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
  20. }
  21. }
  22. func TestContextStringValueWithInvalidType(t *testing.T) {
  23. r, _ := http.NewRequest("GET", "http://example.org", nil)
  24. ctx := r.Context()
  25. ctx = context.WithValue(ctx, ClientIPContextKey, 0)
  26. r = r.WithContext(ctx)
  27. result := getContextStringValue(r, ClientIPContextKey)
  28. expected := ""
  29. if result != expected {
  30. t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
  31. }
  32. }
  33. func TestContextStringValueWhenUnset(t *testing.T) {
  34. r, _ := http.NewRequest("GET", "http://example.org", nil)
  35. result := getContextStringValue(r, ClientIPContextKey)
  36. expected := ""
  37. if result != expected {
  38. t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
  39. }
  40. }
  41. func TestContextBoolValue(t *testing.T) {
  42. r, _ := http.NewRequest("GET", "http://example.org", nil)
  43. ctx := r.Context()
  44. ctx = context.WithValue(ctx, IsAdminUserContextKey, true)
  45. r = r.WithContext(ctx)
  46. result := getContextBoolValue(r, IsAdminUserContextKey)
  47. expected := true
  48. if result != expected {
  49. t.Errorf(`Unexpected context value, got %v instead of %v`, result, expected)
  50. }
  51. }
  52. func TestContextBoolValueWithInvalidType(t *testing.T) {
  53. r, _ := http.NewRequest("GET", "http://example.org", nil)
  54. ctx := r.Context()
  55. ctx = context.WithValue(ctx, IsAdminUserContextKey, "invalid")
  56. r = r.WithContext(ctx)
  57. result := getContextBoolValue(r, IsAdminUserContextKey)
  58. expected := false
  59. if result != expected {
  60. t.Errorf(`Unexpected context value, got %v instead of %v`, result, expected)
  61. }
  62. }
  63. func TestContextBoolValueWhenUnset(t *testing.T) {
  64. r, _ := http.NewRequest("GET", "http://example.org", nil)
  65. result := getContextBoolValue(r, IsAdminUserContextKey)
  66. expected := false
  67. if result != expected {
  68. t.Errorf(`Unexpected context value, got %v instead of %v`, result, expected)
  69. }
  70. }
  71. func TestContextInt64Value(t *testing.T) {
  72. r, _ := http.NewRequest("GET", "http://example.org", nil)
  73. ctx := r.Context()
  74. ctx = context.WithValue(ctx, UserIDContextKey, int64(1234))
  75. r = r.WithContext(ctx)
  76. result := getContextInt64Value(r, UserIDContextKey)
  77. expected := int64(1234)
  78. if result != expected {
  79. t.Errorf(`Unexpected context value, got %d instead of %d`, result, expected)
  80. }
  81. }
  82. func TestContextInt64ValueWithInvalidType(t *testing.T) {
  83. r, _ := http.NewRequest("GET", "http://example.org", nil)
  84. ctx := r.Context()
  85. ctx = context.WithValue(ctx, UserIDContextKey, "invalid")
  86. r = r.WithContext(ctx)
  87. result := getContextInt64Value(r, UserIDContextKey)
  88. expected := int64(0)
  89. if result != expected {
  90. t.Errorf(`Unexpected context value, got %d instead of %d`, result, expected)
  91. }
  92. }
  93. func TestContextInt64ValueWhenUnset(t *testing.T) {
  94. r, _ := http.NewRequest("GET", "http://example.org", nil)
  95. result := getContextInt64Value(r, UserIDContextKey)
  96. expected := int64(0)
  97. if result != expected {
  98. t.Errorf(`Unexpected context value, got %d instead of %d`, result, expected)
  99. }
  100. }
  101. func TestIsAdmin(t *testing.T) {
  102. r, _ := http.NewRequest("GET", "http://example.org", nil)
  103. result := IsAdminUser(r)
  104. expected := false
  105. if result != expected {
  106. t.Errorf(`Unexpected context value, got %v instead of %v`, result, expected)
  107. }
  108. ctx := r.Context()
  109. ctx = context.WithValue(ctx, IsAdminUserContextKey, true)
  110. r = r.WithContext(ctx)
  111. result = IsAdminUser(r)
  112. expected = true
  113. if result != expected {
  114. t.Errorf(`Unexpected context value, got %v instead of %v`, result, expected)
  115. }
  116. }
  117. func TestIsAuthenticated(t *testing.T) {
  118. r, _ := http.NewRequest("GET", "http://example.org", nil)
  119. result := IsAuthenticated(r)
  120. expected := false
  121. if result != expected {
  122. t.Errorf(`Unexpected context value, got %v instead of %v`, result, expected)
  123. }
  124. ctx := r.Context()
  125. ctx = context.WithValue(ctx, IsAuthenticatedContextKey, true)
  126. r = r.WithContext(ctx)
  127. result = IsAuthenticated(r)
  128. expected = true
  129. if result != expected {
  130. t.Errorf(`Unexpected context value, got %v instead of %v`, result, expected)
  131. }
  132. }
  133. func TestUserID(t *testing.T) {
  134. r, _ := http.NewRequest("GET", "http://example.org", nil)
  135. result := UserID(r)
  136. expected := int64(0)
  137. if result != expected {
  138. t.Errorf(`Unexpected context value, got %v instead of %v`, result, expected)
  139. }
  140. ctx := r.Context()
  141. ctx = context.WithValue(ctx, UserIDContextKey, int64(123))
  142. r = r.WithContext(ctx)
  143. result = UserID(r)
  144. expected = int64(123)
  145. if result != expected {
  146. t.Errorf(`Unexpected context value, got %v instead of %v`, result, expected)
  147. }
  148. }
  149. func TestUserName(t *testing.T) {
  150. r, _ := http.NewRequest("GET", "http://example.org", nil)
  151. result := UserName(r)
  152. expected := "unknown"
  153. if result != expected {
  154. t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
  155. }
  156. ctx := r.Context()
  157. ctx = context.WithValue(ctx, UserNameContextKey, "jane")
  158. r = r.WithContext(ctx)
  159. result = UserName(r)
  160. expected = "jane"
  161. if result != expected {
  162. t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
  163. }
  164. }
  165. func TestUserTimezone(t *testing.T) {
  166. r, _ := http.NewRequest("GET", "http://example.org", nil)
  167. result := UserTimezone(r)
  168. expected := "UTC"
  169. if result != expected {
  170. t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
  171. }
  172. ctx := r.Context()
  173. ctx = context.WithValue(ctx, UserTimezoneContextKey, "Europe/Paris")
  174. r = r.WithContext(ctx)
  175. result = UserTimezone(r)
  176. expected = "Europe/Paris"
  177. if result != expected {
  178. t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
  179. }
  180. }
  181. func TestUserLanguage(t *testing.T) {
  182. r, _ := http.NewRequest("GET", "http://example.org", nil)
  183. result := UserLanguage(r)
  184. expected := "en_US"
  185. if result != expected {
  186. t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
  187. }
  188. ctx := r.Context()
  189. ctx = context.WithValue(ctx, UserLanguageContextKey, "fr_FR")
  190. r = r.WithContext(ctx)
  191. result = UserLanguage(r)
  192. expected = "fr_FR"
  193. if result != expected {
  194. t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
  195. }
  196. }
  197. func TestUserTheme(t *testing.T) {
  198. r, _ := http.NewRequest("GET", "http://example.org", nil)
  199. result := UserTheme(r)
  200. expected := "system_serif"
  201. if result != expected {
  202. t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
  203. }
  204. ctx := r.Context()
  205. ctx = context.WithValue(ctx, UserThemeContextKey, "dark_serif")
  206. r = r.WithContext(ctx)
  207. result = UserTheme(r)
  208. expected = "dark_serif"
  209. if result != expected {
  210. t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
  211. }
  212. }
  213. func TestCSRF(t *testing.T) {
  214. r, _ := http.NewRequest("GET", "http://example.org", nil)
  215. result := CSRF(r)
  216. expected := ""
  217. if result != expected {
  218. t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
  219. }
  220. ctx := r.Context()
  221. ctx = context.WithValue(ctx, CSRFContextKey, "secret")
  222. r = r.WithContext(ctx)
  223. result = CSRF(r)
  224. expected = "secret"
  225. if result != expected {
  226. t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
  227. }
  228. }
  229. func TestSessionID(t *testing.T) {
  230. r, _ := http.NewRequest("GET", "http://example.org", nil)
  231. result := SessionID(r)
  232. expected := ""
  233. if result != expected {
  234. t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
  235. }
  236. ctx := r.Context()
  237. ctx = context.WithValue(ctx, SessionIDContextKey, "id")
  238. r = r.WithContext(ctx)
  239. result = SessionID(r)
  240. expected = "id"
  241. if result != expected {
  242. t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
  243. }
  244. }
  245. func TestUserSessionToken(t *testing.T) {
  246. r, _ := http.NewRequest("GET", "http://example.org", nil)
  247. result := UserSessionToken(r)
  248. expected := ""
  249. if result != expected {
  250. t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
  251. }
  252. ctx := r.Context()
  253. ctx = context.WithValue(ctx, UserSessionTokenContextKey, "token")
  254. r = r.WithContext(ctx)
  255. result = UserSessionToken(r)
  256. expected = "token"
  257. if result != expected {
  258. t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
  259. }
  260. }
  261. func TestOAuth2State(t *testing.T) {
  262. r, _ := http.NewRequest("GET", "http://example.org", nil)
  263. result := OAuth2State(r)
  264. expected := ""
  265. if result != expected {
  266. t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
  267. }
  268. ctx := r.Context()
  269. ctx = context.WithValue(ctx, OAuth2StateContextKey, "state")
  270. r = r.WithContext(ctx)
  271. result = OAuth2State(r)
  272. expected = "state"
  273. if result != expected {
  274. t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
  275. }
  276. }
  277. func TestOAuth2CodeVerifier(t *testing.T) {
  278. r, _ := http.NewRequest("GET", "http://example.org", nil)
  279. result := OAuth2CodeVerifier(r)
  280. expected := ""
  281. if result != expected {
  282. t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
  283. }
  284. ctx := r.Context()
  285. ctx = context.WithValue(ctx, OAuth2CodeVerifierContextKey, "verifier")
  286. r = r.WithContext(ctx)
  287. result = OAuth2CodeVerifier(r)
  288. expected = "verifier"
  289. if result != expected {
  290. t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
  291. }
  292. }
  293. func TestFlashMessage(t *testing.T) {
  294. r, _ := http.NewRequest("GET", "http://example.org", nil)
  295. result := FlashMessage(r)
  296. expected := ""
  297. if result != expected {
  298. t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
  299. }
  300. ctx := r.Context()
  301. ctx = context.WithValue(ctx, FlashMessageContextKey, "message")
  302. r = r.WithContext(ctx)
  303. result = FlashMessage(r)
  304. expected = "message"
  305. if result != expected {
  306. t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
  307. }
  308. }
  309. func TestFlashErrorMessage(t *testing.T) {
  310. r, _ := http.NewRequest("GET", "http://example.org", nil)
  311. result := FlashErrorMessage(r)
  312. expected := ""
  313. if result != expected {
  314. t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
  315. }
  316. ctx := r.Context()
  317. ctx = context.WithValue(ctx, FlashErrorMessageContextKey, "error message")
  318. r = r.WithContext(ctx)
  319. result = FlashErrorMessage(r)
  320. expected = "error message"
  321. if result != expected {
  322. t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
  323. }
  324. }
  325. func TestLastForceRefresh(t *testing.T) {
  326. r, _ := http.NewRequest("GET", "http://example.org", nil)
  327. result := LastForceRefresh(r)
  328. expected := time.Time{}
  329. if !result.Equal(expected) {
  330. t.Errorf(`Unexpected context value, got %v instead of %v`, result, expected)
  331. }
  332. ctx := r.Context()
  333. ctx = context.WithValue(ctx, LastForceRefreshContextKey, "not-a-timestamp")
  334. r = r.WithContext(ctx)
  335. result = LastForceRefresh(r)
  336. expected = time.Time{}
  337. if !result.Equal(expected) {
  338. t.Errorf(`Unexpected context value, got %v instead of %v`, result, expected)
  339. }
  340. ctx = r.Context()
  341. ctx = context.WithValue(ctx, LastForceRefreshContextKey, "1700000000")
  342. r = r.WithContext(ctx)
  343. result = LastForceRefresh(r)
  344. expected = time.Unix(1700000000, 0)
  345. if !result.Equal(expected) {
  346. t.Errorf(`Unexpected context value, got %v instead of %v`, result, expected)
  347. }
  348. }
  349. func TestWebAuthnSessionData(t *testing.T) {
  350. r, _ := http.NewRequest("GET", "http://example.org", nil)
  351. result := WebAuthnSessionData(r)
  352. if result != nil {
  353. t.Errorf("Unexpected context value, got %v instead of nil", result)
  354. }
  355. ctx := r.Context()
  356. ctx = context.WithValue(ctx, WebAuthnDataContextKey, "invalid")
  357. r = r.WithContext(ctx)
  358. result = WebAuthnSessionData(r)
  359. if result != nil {
  360. t.Errorf("Unexpected context value, got %v instead of nil", result)
  361. }
  362. session := model.WebAuthnSession{}
  363. ctx = r.Context()
  364. ctx = context.WithValue(ctx, WebAuthnDataContextKey, session)
  365. r = r.WithContext(ctx)
  366. result = WebAuthnSessionData(r)
  367. if result == nil {
  368. t.Errorf("Unexpected context value, got nil instead of session")
  369. }
  370. }
  371. func TestClientIP(t *testing.T) {
  372. r, _ := http.NewRequest("GET", "http://example.org", nil)
  373. result := ClientIP(r)
  374. expected := ""
  375. if result != expected {
  376. t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
  377. }
  378. ctx := r.Context()
  379. ctx = context.WithValue(ctx, ClientIPContextKey, "127.0.0.1")
  380. r = r.WithContext(ctx)
  381. result = ClientIP(r)
  382. expected = "127.0.0.1"
  383. if result != expected {
  384. t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
  385. }
  386. }
  387. func TestGoogleReaderToken(t *testing.T) {
  388. r, _ := http.NewRequest("GET", "http://example.org", nil)
  389. result := GoogleReaderToken(r)
  390. expected := ""
  391. if result != expected {
  392. t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
  393. }
  394. ctx := r.Context()
  395. ctx = context.WithValue(ctx, GoogleReaderTokenKey, "token")
  396. r = r.WithContext(ctx)
  397. result = GoogleReaderToken(r)
  398. expected = "token"
  399. if result != expected {
  400. t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
  401. }
  402. }