middleware_test.go 8.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336
  1. package mux
  2. import (
  3. "bytes"
  4. "net/http"
  5. "testing"
  6. )
  // testMiddleware is a middleware whose only job is to count how many
  // requests have been routed through the handlers it produced.
  type testMiddleware struct {
  	timesCalled uint // bumped once per request that reaches a wrapper built by Middleware
  }

  // Middleware returns a handler that increments tm.timesCalled and then
  // forwards the request, unmodified, to next.
  func (tm *testMiddleware) Middleware(next http.Handler) http.Handler {
  	counting := func(w http.ResponseWriter, r *http.Request) {
  		tm.timesCalled++
  		next.ServeHTTP(w, r)
  	}
  	return http.HandlerFunc(counting)
  }
  // dummyHandler is a no-op handler for routes whose response body is
  // irrelevant to the test; it ignores both the writer and the request.
  func dummyHandler(http.ResponseWriter, *http.Request) {}
  17. func TestMiddlewareAdd(t *testing.T) {
  18. router := NewRouter()
  19. router.HandleFunc("/", dummyHandler).Methods("GET")
  20. mw := &testMiddleware{}
  21. router.useInterface(mw)
  22. if len(router.middlewares) != 1 || router.middlewares[0] != mw {
  23. t.Fatal("Middleware was not added correctly")
  24. }
  25. router.Use(mw.Middleware)
  26. if len(router.middlewares) != 2 {
  27. t.Fatal("MiddlewareFunc method was not added correctly")
  28. }
  29. banalMw := func(handler http.Handler) http.Handler {
  30. return handler
  31. }
  32. router.Use(banalMw)
  33. if len(router.middlewares) != 3 {
  34. t.Fatal("MiddlewareFunc method was not added correctly")
  35. }
  36. }
  37. func TestMiddleware(t *testing.T) {
  38. router := NewRouter()
  39. router.HandleFunc("/", dummyHandler).Methods("GET")
  40. mw := &testMiddleware{}
  41. router.useInterface(mw)
  42. rw := NewRecorder()
  43. req := newRequest("GET", "/")
  44. // Test regular middleware call
  45. router.ServeHTTP(rw, req)
  46. if mw.timesCalled != 1 {
  47. t.Fatalf("Expected %d calls, but got only %d", 1, mw.timesCalled)
  48. }
  49. // Middleware should not be called for 404
  50. req = newRequest("GET", "/not/found")
  51. router.ServeHTTP(rw, req)
  52. if mw.timesCalled != 1 {
  53. t.Fatalf("Expected %d calls, but got only %d", 1, mw.timesCalled)
  54. }
  55. // Middleware should not be called if there is a method mismatch
  56. req = newRequest("POST", "/")
  57. router.ServeHTTP(rw, req)
  58. if mw.timesCalled != 1 {
  59. t.Fatalf("Expected %d calls, but got only %d", 1, mw.timesCalled)
  60. }
  61. // Add the middleware again as function
  62. router.Use(mw.Middleware)
  63. req = newRequest("GET", "/")
  64. router.ServeHTTP(rw, req)
  65. if mw.timesCalled != 3 {
  66. t.Fatalf("Expected %d calls, but got only %d", 3, mw.timesCalled)
  67. }
  68. }
  69. func TestMiddlewareSubrouter(t *testing.T) {
  70. router := NewRouter()
  71. router.HandleFunc("/", dummyHandler).Methods("GET")
  72. subrouter := router.PathPrefix("/sub").Subrouter()
  73. subrouter.HandleFunc("/x", dummyHandler).Methods("GET")
  74. mw := &testMiddleware{}
  75. subrouter.useInterface(mw)
  76. rw := NewRecorder()
  77. req := newRequest("GET", "/")
  78. router.ServeHTTP(rw, req)
  79. if mw.timesCalled != 0 {
  80. t.Fatalf("Expected %d calls, but got only %d", 0, mw.timesCalled)
  81. }
  82. req = newRequest("GET", "/sub/")
  83. router.ServeHTTP(rw, req)
  84. if mw.timesCalled != 0 {
  85. t.Fatalf("Expected %d calls, but got only %d", 0, mw.timesCalled)
  86. }
  87. req = newRequest("GET", "/sub/x")
  88. router.ServeHTTP(rw, req)
  89. if mw.timesCalled != 1 {
  90. t.Fatalf("Expected %d calls, but got only %d", 1, mw.timesCalled)
  91. }
  92. req = newRequest("GET", "/sub/not/found")
  93. router.ServeHTTP(rw, req)
  94. if mw.timesCalled != 1 {
  95. t.Fatalf("Expected %d calls, but got only %d", 1, mw.timesCalled)
  96. }
  97. router.useInterface(mw)
  98. req = newRequest("GET", "/")
  99. router.ServeHTTP(rw, req)
  100. if mw.timesCalled != 2 {
  101. t.Fatalf("Expected %d calls, but got only %d", 2, mw.timesCalled)
  102. }
  103. req = newRequest("GET", "/sub/x")
  104. router.ServeHTTP(rw, req)
  105. if mw.timesCalled != 4 {
  106. t.Fatalf("Expected %d calls, but got only %d", 4, mw.timesCalled)
  107. }
  108. }
  109. func TestMiddlewareExecution(t *testing.T) {
  110. mwStr := []byte("Middleware\n")
  111. handlerStr := []byte("Logic\n")
  112. router := NewRouter()
  113. router.HandleFunc("/", func(w http.ResponseWriter, e *http.Request) {
  114. w.Write(handlerStr)
  115. })
  116. rw := NewRecorder()
  117. req := newRequest("GET", "/")
  118. // Test handler-only call
  119. router.ServeHTTP(rw, req)
  120. if bytes.Compare(rw.Body.Bytes(), handlerStr) != 0 {
  121. t.Fatal("Handler response is not what it should be")
  122. }
  123. // Test middleware call
  124. rw = NewRecorder()
  125. router.Use(func(h http.Handler) http.Handler {
  126. return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  127. w.Write(mwStr)
  128. h.ServeHTTP(w, r)
  129. })
  130. })
  131. router.ServeHTTP(rw, req)
  132. if bytes.Compare(rw.Body.Bytes(), append(mwStr, handlerStr...)) != 0 {
  133. t.Fatal("Middleware + handler response is not what it should be")
  134. }
  135. }
  136. func TestMiddlewareNotFound(t *testing.T) {
  137. mwStr := []byte("Middleware\n")
  138. handlerStr := []byte("Logic\n")
  139. router := NewRouter()
  140. router.HandleFunc("/", func(w http.ResponseWriter, e *http.Request) {
  141. w.Write(handlerStr)
  142. })
  143. router.Use(func(h http.Handler) http.Handler {
  144. return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  145. w.Write(mwStr)
  146. h.ServeHTTP(w, r)
  147. })
  148. })
  149. // Test not found call with default handler
  150. rw := NewRecorder()
  151. req := newRequest("GET", "/notfound")
  152. router.ServeHTTP(rw, req)
  153. if bytes.Contains(rw.Body.Bytes(), mwStr) {
  154. t.Fatal("Middleware was called for a 404")
  155. }
  156. // Test not found call with custom handler
  157. rw = NewRecorder()
  158. req = newRequest("GET", "/notfound")
  159. router.NotFoundHandler = http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
  160. rw.Write([]byte("Custom 404 handler"))
  161. })
  162. router.ServeHTTP(rw, req)
  163. if bytes.Contains(rw.Body.Bytes(), mwStr) {
  164. t.Fatal("Middleware was called for a custom 404")
  165. }
  166. }
  167. func TestMiddlewareMethodMismatch(t *testing.T) {
  168. mwStr := []byte("Middleware\n")
  169. handlerStr := []byte("Logic\n")
  170. router := NewRouter()
  171. router.HandleFunc("/", func(w http.ResponseWriter, e *http.Request) {
  172. w.Write(handlerStr)
  173. }).Methods("GET")
  174. router.Use(func(h http.Handler) http.Handler {
  175. return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  176. w.Write(mwStr)
  177. h.ServeHTTP(w, r)
  178. })
  179. })
  180. // Test method mismatch
  181. rw := NewRecorder()
  182. req := newRequest("POST", "/")
  183. router.ServeHTTP(rw, req)
  184. if bytes.Contains(rw.Body.Bytes(), mwStr) {
  185. t.Fatal("Middleware was called for a method mismatch")
  186. }
  187. // Test not found call
  188. rw = NewRecorder()
  189. req = newRequest("POST", "/")
  190. router.MethodNotAllowedHandler = http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
  191. rw.Write([]byte("Method not allowed"))
  192. })
  193. router.ServeHTTP(rw, req)
  194. if bytes.Contains(rw.Body.Bytes(), mwStr) {
  195. t.Fatal("Middleware was called for a method mismatch")
  196. }
  197. }
  198. func TestMiddlewareNotFoundSubrouter(t *testing.T) {
  199. mwStr := []byte("Middleware\n")
  200. handlerStr := []byte("Logic\n")
  201. router := NewRouter()
  202. router.HandleFunc("/", func(w http.ResponseWriter, e *http.Request) {
  203. w.Write(handlerStr)
  204. })
  205. subrouter := router.PathPrefix("/sub/").Subrouter()
  206. subrouter.HandleFunc("/", func(w http.ResponseWriter, e *http.Request) {
  207. w.Write(handlerStr)
  208. })
  209. router.Use(func(h http.Handler) http.Handler {
  210. return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  211. w.Write(mwStr)
  212. h.ServeHTTP(w, r)
  213. })
  214. })
  215. // Test not found call for default handler
  216. rw := NewRecorder()
  217. req := newRequest("GET", "/sub/notfound")
  218. router.ServeHTTP(rw, req)
  219. if bytes.Contains(rw.Body.Bytes(), mwStr) {
  220. t.Fatal("Middleware was called for a 404")
  221. }
  222. // Test not found call with custom handler
  223. rw = NewRecorder()
  224. req = newRequest("GET", "/sub/notfound")
  225. subrouter.NotFoundHandler = http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
  226. rw.Write([]byte("Custom 404 handler"))
  227. })
  228. router.ServeHTTP(rw, req)
  229. if bytes.Contains(rw.Body.Bytes(), mwStr) {
  230. t.Fatal("Middleware was called for a custom 404")
  231. }
  232. }
  233. func TestMiddlewareMethodMismatchSubrouter(t *testing.T) {
  234. mwStr := []byte("Middleware\n")
  235. handlerStr := []byte("Logic\n")
  236. router := NewRouter()
  237. router.HandleFunc("/", func(w http.ResponseWriter, e *http.Request) {
  238. w.Write(handlerStr)
  239. })
  240. subrouter := router.PathPrefix("/sub/").Subrouter()
  241. subrouter.HandleFunc("/", func(w http.ResponseWriter, e *http.Request) {
  242. w.Write(handlerStr)
  243. }).Methods("GET")
  244. router.Use(func(h http.Handler) http.Handler {
  245. return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  246. w.Write(mwStr)
  247. h.ServeHTTP(w, r)
  248. })
  249. })
  250. // Test method mismatch without custom handler
  251. rw := NewRecorder()
  252. req := newRequest("POST", "/sub/")
  253. router.ServeHTTP(rw, req)
  254. if bytes.Contains(rw.Body.Bytes(), mwStr) {
  255. t.Fatal("Middleware was called for a method mismatch")
  256. }
  257. // Test method mismatch with custom handler
  258. rw = NewRecorder()
  259. req = newRequest("POST", "/sub/")
  260. router.MethodNotAllowedHandler = http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
  261. rw.Write([]byte("Method not allowed"))
  262. })
  263. router.ServeHTTP(rw, req)
  264. if bytes.Contains(rw.Body.Bytes(), mwStr) {
  265. t.Fatal("Middleware was called for a method mismatch")
  266. }
  267. }