4
0
Эх сурвалжийг харах

feat(request): support ServeMux PathValue route params

Frédéric Guillot 2 долоо хоног өмнө
parent
commit
1e42cec0ee

+ 11 - 4
internal/http/request/params.go

@@ -24,8 +24,7 @@ func FormInt64Value(r *http.Request, param string) int64 {
 
 // RouteInt64Param returns the named route parameter parsed as int64, or 0 when missing or invalid.
 func RouteInt64Param(r *http.Request, param string) int64 {
-	vars := mux.Vars(r)
-	value, err := strconv.ParseInt(vars[param], 10, 64)
+	value, err := strconv.ParseInt(routeParam(r, param), 10, 64)
 	if err != nil {
 		return 0
 	}
@@ -39,8 +38,7 @@ func RouteInt64Param(r *http.Request, param string) int64 {
 
 // RouteStringParam returns the named route parameter as a string.
 func RouteStringParam(r *http.Request, param string) string {
-	vars := mux.Vars(r)
-	return vars[param]
+	return routeParam(r, param)
 }
 
 // QueryStringParam returns the named query parameter, or defaultValue if it is empty.
@@ -129,3 +127,12 @@ func HasQueryParam(r *http.Request, param string) bool {
 	_, ok := values[param]
 	return ok
 }
+
+func routeParam(r *http.Request, param string) string {
+	vars := mux.Vars(r)
+	if value, found := vars[param]; found {
+		return value
+	}
+
+	return r.PathValue(param)
+}

+ 70 - 2
internal/http/request/params_test.go

@@ -42,7 +42,7 @@ func TestFormInt64Value(t *testing.T) {
 	}
 }
 
-func TestRouteStringParam(t *testing.T) {
+func TestRouteStringParamWithGorillaMux(t *testing.T) {
 	router := mux.NewRouter()
 	router.HandleFunc("/route/{variable}/index", func(w http.ResponseWriter, r *http.Request) {
 		result := RouteStringParam(r, "variable")
@@ -69,7 +69,34 @@ func TestRouteStringParam(t *testing.T) {
 	router.ServeHTTP(w, r)
 }
 
-func TestRouteInt64Param(t *testing.T) {
+func TestRouteStringParamWithServerMux(t *testing.T) {
+	router := http.NewServeMux()
+	router.HandleFunc("GET /route/{variable}/index", func(w http.ResponseWriter, r *http.Request) {
+		result := RouteStringParam(r, "variable")
+		expected := "value"
+
+		if result != expected {
+			t.Errorf(`Unexpected result, got %q instead of %q`, result, expected)
+		}
+
+		result = RouteStringParam(r, "missing variable")
+		expected = ""
+
+		if result != expected {
+			t.Errorf(`Unexpected result, got %q instead of %q`, result, expected)
+		}
+	})
+
+	r, err := http.NewRequest(http.MethodGet, "/route/value/index", nil)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	w := httptest.NewRecorder()
+	router.ServeHTTP(w, r)
+}
+
+func TestRouteInt64ParamWithGorillaMux(t *testing.T) {
 	router := mux.NewRouter()
 	router.HandleFunc("/a/{variable1}/b/{variable2}/c/{variable3}", func(w http.ResponseWriter, r *http.Request) {
 		result := RouteInt64Param(r, "variable1")
@@ -110,6 +137,47 @@ func TestRouteInt64Param(t *testing.T) {
 	router.ServeHTTP(w, r)
 }
 
+func TestRouteInt64ParamWithServerMux(t *testing.T) {
+	router := http.NewServeMux()
+	router.HandleFunc("GET /a/{variable1}/b/{variable2}/c/{variable3}", func(w http.ResponseWriter, r *http.Request) {
+		result := RouteInt64Param(r, "variable1")
+		expected := int64(42)
+
+		if result != expected {
+			t.Errorf(`Unexpected result, got %d instead of %d`, result, expected)
+		}
+
+		result = RouteInt64Param(r, "missing variable")
+		expected = 0
+
+		if result != expected {
+			t.Errorf(`Unexpected result, got %d instead of %d`, result, expected)
+		}
+
+		result = RouteInt64Param(r, "variable2")
+		expected = 0
+
+		if result != expected {
+			t.Errorf(`Unexpected result, got %d instead of %d`, result, expected)
+		}
+
+		result = RouteInt64Param(r, "variable3")
+		expected = 0
+
+		if result != expected {
+			t.Errorf(`Unexpected result, got %d instead of %d`, result, expected)
+		}
+	})
+
+	r, err := http.NewRequest(http.MethodGet, "/a/42/b/not-int/c/-10", nil)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	w := httptest.NewRecorder()
+	router.ServeHTTP(w, r)
+}
+
 func TestQueryStringParam(t *testing.T) {
 	u, _ := url.Parse("http://example.org/?key=value")
 	r := &http.Request{URL: u}