context_test.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436
  1. // Copyright 2018 Frédéric Guillot. All rights reserved.
  2. // Use of this source code is governed by the Apache 2.0
  3. // license that can be found in the LICENSE file.
  4. package request // import "miniflux.app/http/request"
  5. import (
  6. "context"
  7. "net/http"
  8. "testing"
  9. )
  10. func TestContextStringValue(t *testing.T) {
  11. r, _ := http.NewRequest("GET", "http://example.org", nil)
  12. ctx := r.Context()
  13. ctx = context.WithValue(ctx, ClientIPContextKey, "IP")
  14. r = r.WithContext(ctx)
  15. result := getContextStringValue(r, ClientIPContextKey)
  16. expected := "IP"
  17. if result != expected {
  18. t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
  19. }
  20. }
  21. func TestContextStringValueWithInvalidType(t *testing.T) {
  22. r, _ := http.NewRequest("GET", "http://example.org", nil)
  23. ctx := r.Context()
  24. ctx = context.WithValue(ctx, ClientIPContextKey, 0)
  25. r = r.WithContext(ctx)
  26. result := getContextStringValue(r, ClientIPContextKey)
  27. expected := ""
  28. if result != expected {
  29. t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
  30. }
  31. }
  32. func TestContextStringValueWhenUnset(t *testing.T) {
  33. r, _ := http.NewRequest("GET", "http://example.org", nil)
  34. result := getContextStringValue(r, ClientIPContextKey)
  35. expected := ""
  36. if result != expected {
  37. t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
  38. }
  39. }
  40. func TestContextBoolValue(t *testing.T) {
  41. r, _ := http.NewRequest("GET", "http://example.org", nil)
  42. ctx := r.Context()
  43. ctx = context.WithValue(ctx, IsAdminUserContextKey, true)
  44. r = r.WithContext(ctx)
  45. result := getContextBoolValue(r, IsAdminUserContextKey)
  46. expected := true
  47. if result != expected {
  48. t.Errorf(`Unexpected context value, got %v instead of %v`, result, expected)
  49. }
  50. }
  51. func TestContextBoolValueWithInvalidType(t *testing.T) {
  52. r, _ := http.NewRequest("GET", "http://example.org", nil)
  53. ctx := r.Context()
  54. ctx = context.WithValue(ctx, IsAdminUserContextKey, "invalid")
  55. r = r.WithContext(ctx)
  56. result := getContextBoolValue(r, IsAdminUserContextKey)
  57. expected := false
  58. if result != expected {
  59. t.Errorf(`Unexpected context value, got %v instead of %v`, result, expected)
  60. }
  61. }
  62. func TestContextBoolValueWhenUnset(t *testing.T) {
  63. r, _ := http.NewRequest("GET", "http://example.org", nil)
  64. result := getContextBoolValue(r, IsAdminUserContextKey)
  65. expected := false
  66. if result != expected {
  67. t.Errorf(`Unexpected context value, got %v instead of %v`, result, expected)
  68. }
  69. }
  70. func TestContextInt64Value(t *testing.T) {
  71. r, _ := http.NewRequest("GET", "http://example.org", nil)
  72. ctx := r.Context()
  73. ctx = context.WithValue(ctx, UserIDContextKey, int64(1234))
  74. r = r.WithContext(ctx)
  75. result := getContextInt64Value(r, UserIDContextKey)
  76. expected := int64(1234)
  77. if result != expected {
  78. t.Errorf(`Unexpected context value, got %d instead of %d`, result, expected)
  79. }
  80. }
  81. func TestContextInt64ValueWithInvalidType(t *testing.T) {
  82. r, _ := http.NewRequest("GET", "http://example.org", nil)
  83. ctx := r.Context()
  84. ctx = context.WithValue(ctx, UserIDContextKey, "invalid")
  85. r = r.WithContext(ctx)
  86. result := getContextInt64Value(r, UserIDContextKey)
  87. expected := int64(0)
  88. if result != expected {
  89. t.Errorf(`Unexpected context value, got %d instead of %d`, result, expected)
  90. }
  91. }
  92. func TestContextInt64ValueWhenUnset(t *testing.T) {
  93. r, _ := http.NewRequest("GET", "http://example.org", nil)
  94. result := getContextInt64Value(r, UserIDContextKey)
  95. expected := int64(0)
  96. if result != expected {
  97. t.Errorf(`Unexpected context value, got %d instead of %d`, result, expected)
  98. }
  99. }
  100. func TestIsAdmin(t *testing.T) {
  101. r, _ := http.NewRequest("GET", "http://example.org", nil)
  102. result := IsAdminUser(r)
  103. expected := false
  104. if result != expected {
  105. t.Errorf(`Unexpected context value, got %v instead of %v`, result, expected)
  106. }
  107. ctx := r.Context()
  108. ctx = context.WithValue(ctx, IsAdminUserContextKey, true)
  109. r = r.WithContext(ctx)
  110. result = IsAdminUser(r)
  111. expected = true
  112. if result != expected {
  113. t.Errorf(`Unexpected context value, got %v instead of %v`, result, expected)
  114. }
  115. }
  116. func TestIsAuthenticated(t *testing.T) {
  117. r, _ := http.NewRequest("GET", "http://example.org", nil)
  118. result := IsAuthenticated(r)
  119. expected := false
  120. if result != expected {
  121. t.Errorf(`Unexpected context value, got %v instead of %v`, result, expected)
  122. }
  123. ctx := r.Context()
  124. ctx = context.WithValue(ctx, IsAuthenticatedContextKey, true)
  125. r = r.WithContext(ctx)
  126. result = IsAuthenticated(r)
  127. expected = true
  128. if result != expected {
  129. t.Errorf(`Unexpected context value, got %v instead of %v`, result, expected)
  130. }
  131. }
  132. func TestUserID(t *testing.T) {
  133. r, _ := http.NewRequest("GET", "http://example.org", nil)
  134. result := UserID(r)
  135. expected := int64(0)
  136. if result != expected {
  137. t.Errorf(`Unexpected context value, got %v instead of %v`, result, expected)
  138. }
  139. ctx := r.Context()
  140. ctx = context.WithValue(ctx, UserIDContextKey, int64(123))
  141. r = r.WithContext(ctx)
  142. result = UserID(r)
  143. expected = int64(123)
  144. if result != expected {
  145. t.Errorf(`Unexpected context value, got %v instead of %v`, result, expected)
  146. }
  147. }
  148. func TestUserTimezone(t *testing.T) {
  149. r, _ := http.NewRequest("GET", "http://example.org", nil)
  150. result := UserTimezone(r)
  151. expected := "UTC"
  152. if result != expected {
  153. t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
  154. }
  155. ctx := r.Context()
  156. ctx = context.WithValue(ctx, UserTimezoneContextKey, "Europe/Paris")
  157. r = r.WithContext(ctx)
  158. result = UserTimezone(r)
  159. expected = "Europe/Paris"
  160. if result != expected {
  161. t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
  162. }
  163. }
  164. func TestUserLanguage(t *testing.T) {
  165. r, _ := http.NewRequest("GET", "http://example.org", nil)
  166. result := UserLanguage(r)
  167. expected := "en_US"
  168. if result != expected {
  169. t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
  170. }
  171. ctx := r.Context()
  172. ctx = context.WithValue(ctx, UserLanguageContextKey, "fr_FR")
  173. r = r.WithContext(ctx)
  174. result = UserLanguage(r)
  175. expected = "fr_FR"
  176. if result != expected {
  177. t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
  178. }
  179. }
  180. func TestUserTheme(t *testing.T) {
  181. r, _ := http.NewRequest("GET", "http://example.org", nil)
  182. result := UserTheme(r)
  183. expected := "system_serif"
  184. if result != expected {
  185. t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
  186. }
  187. ctx := r.Context()
  188. ctx = context.WithValue(ctx, UserThemeContextKey, "dark_serif")
  189. r = r.WithContext(ctx)
  190. result = UserTheme(r)
  191. expected = "dark_serif"
  192. if result != expected {
  193. t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
  194. }
  195. }
  196. func TestCSRF(t *testing.T) {
  197. r, _ := http.NewRequest("GET", "http://example.org", nil)
  198. result := CSRF(r)
  199. expected := ""
  200. if result != expected {
  201. t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
  202. }
  203. ctx := r.Context()
  204. ctx = context.WithValue(ctx, CSRFContextKey, "secret")
  205. r = r.WithContext(ctx)
  206. result = CSRF(r)
  207. expected = "secret"
  208. if result != expected {
  209. t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
  210. }
  211. }
  212. func TestSessionID(t *testing.T) {
  213. r, _ := http.NewRequest("GET", "http://example.org", nil)
  214. result := SessionID(r)
  215. expected := ""
  216. if result != expected {
  217. t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
  218. }
  219. ctx := r.Context()
  220. ctx = context.WithValue(ctx, SessionIDContextKey, "id")
  221. r = r.WithContext(ctx)
  222. result = SessionID(r)
  223. expected = "id"
  224. if result != expected {
  225. t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
  226. }
  227. }
  228. func TestUserSessionToken(t *testing.T) {
  229. r, _ := http.NewRequest("GET", "http://example.org", nil)
  230. result := UserSessionToken(r)
  231. expected := ""
  232. if result != expected {
  233. t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
  234. }
  235. ctx := r.Context()
  236. ctx = context.WithValue(ctx, UserSessionTokenContextKey, "token")
  237. r = r.WithContext(ctx)
  238. result = UserSessionToken(r)
  239. expected = "token"
  240. if result != expected {
  241. t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
  242. }
  243. }
  244. func TestOAuth2State(t *testing.T) {
  245. r, _ := http.NewRequest("GET", "http://example.org", nil)
  246. result := OAuth2State(r)
  247. expected := ""
  248. if result != expected {
  249. t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
  250. }
  251. ctx := r.Context()
  252. ctx = context.WithValue(ctx, OAuth2StateContextKey, "state")
  253. r = r.WithContext(ctx)
  254. result = OAuth2State(r)
  255. expected = "state"
  256. if result != expected {
  257. t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
  258. }
  259. }
  260. func TestFlashMessage(t *testing.T) {
  261. r, _ := http.NewRequest("GET", "http://example.org", nil)
  262. result := FlashMessage(r)
  263. expected := ""
  264. if result != expected {
  265. t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
  266. }
  267. ctx := r.Context()
  268. ctx = context.WithValue(ctx, FlashMessageContextKey, "message")
  269. r = r.WithContext(ctx)
  270. result = FlashMessage(r)
  271. expected = "message"
  272. if result != expected {
  273. t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
  274. }
  275. }
  276. func TestFlashErrorMessage(t *testing.T) {
  277. r, _ := http.NewRequest("GET", "http://example.org", nil)
  278. result := FlashErrorMessage(r)
  279. expected := ""
  280. if result != expected {
  281. t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
  282. }
  283. ctx := r.Context()
  284. ctx = context.WithValue(ctx, FlashErrorMessageContextKey, "error message")
  285. r = r.WithContext(ctx)
  286. result = FlashErrorMessage(r)
  287. expected = "error message"
  288. if result != expected {
  289. t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
  290. }
  291. }
  292. func TestPocketRequestToken(t *testing.T) {
  293. r, _ := http.NewRequest("GET", "http://example.org", nil)
  294. result := PocketRequestToken(r)
  295. expected := ""
  296. if result != expected {
  297. t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
  298. }
  299. ctx := r.Context()
  300. ctx = context.WithValue(ctx, PocketRequestTokenContextKey, "request token")
  301. r = r.WithContext(ctx)
  302. result = PocketRequestToken(r)
  303. expected = "request token"
  304. if result != expected {
  305. t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
  306. }
  307. }
  308. func TestClientIP(t *testing.T) {
  309. r, _ := http.NewRequest("GET", "http://example.org", nil)
  310. result := ClientIP(r)
  311. expected := ""
  312. if result != expected {
  313. t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
  314. }
  315. ctx := r.Context()
  316. ctx = context.WithValue(ctx, ClientIPContextKey, "127.0.0.1")
  317. r = r.WithContext(ctx)
  318. result = ClientIP(r)
  319. expected = "127.0.0.1"
  320. if result != expected {
  321. t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
  322. }
  323. }