Parcourir la source

test(request): add 100% unit test coverage

Frédéric Guillot il y a 2 mois
Parent
commit
3a232d0c8d

+ 3 - 3
internal/http/request/client_ip.go

@@ -9,7 +9,7 @@ import (
 	"strings"
 )
 
-// IsTrustedIP checks if the given remote IP belongs to one of the trusted networks.
+// IsTrustedIP reports whether the given remote IP address belongs to one of the trusted networks.
 func IsTrustedIP(remoteIP string, trustedNetworks []string) bool {
 	if remoteIP == "@" || strings.HasPrefix(remoteIP, "/") {
 		return true
@@ -34,7 +34,7 @@ func IsTrustedIP(remoteIP string, trustedNetworks []string) bool {
 	return false
 }
 
-// FindClientIP returns the client real IP address based on trusted Reverse-Proxy HTTP headers.
+// FindClientIP returns the real client IP address using trusted reverse-proxy headers when allowed.
 func FindClientIP(r *http.Request, isTrustedProxyClient bool) string {
 	if isTrustedProxyClient {
 		headers := [...]string{"X-Forwarded-For", "X-Real-Ip"}
@@ -57,7 +57,7 @@ func FindClientIP(r *http.Request, isTrustedProxyClient bool) string {
 	return FindRemoteIP(r)
 }
 
-// FindRemoteIP returns remote client IP address without considering HTTP headers.
+// FindRemoteIP returns the remote client IP address without considering HTTP headers.
 func FindRemoteIP(r *http.Request) string {
 	remoteIP, _, err := net.SplitHostPort(r.RemoteAddr)
 	if err != nil {

+ 18 - 16
internal/http/request/context.go

@@ -36,6 +36,7 @@ const (
 	WebAuthnDataContextKey
 )
 
+// WebAuthnSessionData returns WebAuthn session data from the request context, or nil if absent.
 func WebAuthnSessionData(r *http.Request) *model.WebAuthnSession {
 	if v := r.Context().Value(WebAuthnDataContextKey); v != nil {
 		if value, valid := v.(model.WebAuthnSession); valid {
@@ -45,27 +46,27 @@ func WebAuthnSessionData(r *http.Request) *model.WebAuthnSession {
 	return nil
 }
 
-// GoogleReaderToken returns the google reader token if it exists.
+// GoogleReaderToken returns the Google Reader token from the request context, if present.
 func GoogleReaderToken(r *http.Request) string {
 	return getContextStringValue(r, GoogleReaderTokenKey)
 }
 
-// IsAdminUser checks if the logged user is administrator.
+// IsAdminUser reports whether the logged-in user is an administrator.
 func IsAdminUser(r *http.Request) bool {
 	return getContextBoolValue(r, IsAdminUserContextKey)
 }
 
-// IsAuthenticated returns a boolean if the user is authenticated.
+// IsAuthenticated reports whether the user is authenticated.
 func IsAuthenticated(r *http.Request) bool {
 	return getContextBoolValue(r, IsAuthenticatedContextKey)
 }
 
-// UserID returns the UserID of the logged user.
+// UserID returns the logged-in user's ID from the request context.
 func UserID(r *http.Request) int64 {
 	return getContextInt64Value(r, UserIDContextKey)
 }
 
-// UserName returns the username of the logged user.
+// UserName returns the logged-in user's username, or "unknown" when unset.
 func UserName(r *http.Request) string {
 	value := getContextStringValue(r, UserNameContextKey)
 	if value == "" {
@@ -74,7 +75,7 @@ func UserName(r *http.Request) string {
 	return value
 }
 
-// UserTimezone returns the timezone used by the logged user.
+// UserTimezone returns the user's timezone, defaulting to "UTC" when unset.
 func UserTimezone(r *http.Request) string {
 	value := getContextStringValue(r, UserTimezoneContextKey)
 	if value == "" {
@@ -83,7 +84,7 @@ func UserTimezone(r *http.Request) string {
 	return value
 }
 
-// UserLanguage get the locale used by the current logged user.
+// UserLanguage returns the user's locale, defaulting to "en_US" when unset.
 func UserLanguage(r *http.Request) string {
 	language := getContextStringValue(r, UserLanguageContextKey)
 	if language == "" {
@@ -92,7 +93,7 @@ func UserLanguage(r *http.Request) string {
 	return language
 }
 
-// UserTheme get the theme used by the current logged user.
+// UserTheme returns the user's theme, defaulting to "system_serif" when unset.
 func UserTheme(r *http.Request) string {
 	theme := getContextStringValue(r, UserThemeContextKey)
 	if theme == "" {
@@ -101,41 +102,42 @@ func UserTheme(r *http.Request) string {
 	return theme
 }
 
-// CSRF returns the current CSRF token.
+// CSRF returns the CSRF token from the request context.
 func CSRF(r *http.Request) string {
 	return getContextStringValue(r, CSRFContextKey)
 }
 
-// SessionID returns the current session ID.
+// SessionID returns the current session ID from the request context.
 func SessionID(r *http.Request) string {
 	return getContextStringValue(r, SessionIDContextKey)
 }
 
-// UserSessionToken returns the current user session token.
+// UserSessionToken returns the current user session token from the request context.
 func UserSessionToken(r *http.Request) string {
 	return getContextStringValue(r, UserSessionTokenContextKey)
 }
 
-// OAuth2State returns the current OAuth2 state.
+// OAuth2State returns the OAuth2 state value from the request context.
 func OAuth2State(r *http.Request) string {
 	return getContextStringValue(r, OAuth2StateContextKey)
 }
 
+// OAuth2CodeVerifier returns the OAuth2 PKCE code verifier from the request context.
 func OAuth2CodeVerifier(r *http.Request) string {
 	return getContextStringValue(r, OAuth2CodeVerifierContextKey)
 }
 
-// FlashMessage returns the message message if any.
+// FlashMessage returns the flash message from the request context, if any.
 func FlashMessage(r *http.Request) string {
 	return getContextStringValue(r, FlashMessageContextKey)
 }
 
-// FlashErrorMessage returns the message error message if any.
+// FlashErrorMessage returns the flash error message from the request context, if any.
 func FlashErrorMessage(r *http.Request) string {
 	return getContextStringValue(r, FlashErrorMessageContextKey)
 }
 
-// LastForceRefresh returns the last force refresh timestamp.
+// LastForceRefresh returns the last force refresh timestamp from the request context.
 func LastForceRefresh(r *http.Request) time.Time {
 	jsonStringValue := getContextStringValue(r, LastForceRefreshContextKey)
 	timestamp, err := strconv.ParseInt(jsonStringValue, 10, 64)
@@ -145,7 +147,7 @@ func LastForceRefresh(r *http.Request) time.Time {
 	return time.Unix(timestamp, 0)
 }
 
-// ClientIP returns the client IP address stored in the context.
+// ClientIP returns the client IP address stored in the request context.
 func ClientIP(r *http.Request) string {
 	return getContextStringValue(r, ClientIPContextKey)
 }

+ 130 - 0
internal/http/request/context_test.go

@@ -7,6 +7,9 @@ import (
 	"context"
 	"net/http"
 	"testing"
+	"time"
+
+	"miniflux.app/v2/internal/model"
 )
 
 func TestContextStringValue(t *testing.T) {
@@ -192,6 +195,28 @@ func TestUserID(t *testing.T) {
 	}
 }
 
+func TestUserName(t *testing.T) {
+	r, _ := http.NewRequest("GET", "http://example.org", nil)
+
+	result := UserName(r)
+	expected := "unknown"
+
+	if result != expected {
+		t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
+	}
+
+	ctx := r.Context()
+	ctx = context.WithValue(ctx, UserNameContextKey, "jane")
+	r = r.WithContext(ctx)
+
+	result = UserName(r)
+	expected = "jane"
+
+	if result != expected {
+		t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
+	}
+}
+
 func TestUserTimezone(t *testing.T) {
 	r, _ := http.NewRequest("GET", "http://example.org", nil)
 
@@ -346,6 +371,28 @@ func TestOAuth2State(t *testing.T) {
 	}
 }
 
+func TestOAuth2CodeVerifier(t *testing.T) {
+	r, _ := http.NewRequest("GET", "http://example.org", nil)
+
+	result := OAuth2CodeVerifier(r)
+	expected := ""
+
+	if result != expected {
+		t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
+	}
+
+	ctx := r.Context()
+	ctx = context.WithValue(ctx, OAuth2CodeVerifierContextKey, "verifier")
+	r = r.WithContext(ctx)
+
+	result = OAuth2CodeVerifier(r)
+	expected = "verifier"
+
+	if result != expected {
+		t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
+	}
+}
+
 func TestFlashMessage(t *testing.T) {
 	r, _ := http.NewRequest("GET", "http://example.org", nil)
 
@@ -390,6 +437,67 @@ func TestFlashErrorMessage(t *testing.T) {
 	}
 }
 
+func TestLastForceRefresh(t *testing.T) {
+	r, _ := http.NewRequest("GET", "http://example.org", nil)
+
+	result := LastForceRefresh(r)
+	expected := time.Time{}
+
+	if !result.Equal(expected) {
+		t.Errorf(`Unexpected context value, got %v instead of %v`, result, expected)
+	}
+
+	ctx := r.Context()
+	ctx = context.WithValue(ctx, LastForceRefreshContextKey, "not-a-timestamp")
+	r = r.WithContext(ctx)
+
+	result = LastForceRefresh(r)
+	expected = time.Time{}
+
+	if !result.Equal(expected) {
+		t.Errorf(`Unexpected context value, got %v instead of %v`, result, expected)
+	}
+
+	ctx = r.Context()
+	ctx = context.WithValue(ctx, LastForceRefreshContextKey, "1700000000")
+	r = r.WithContext(ctx)
+
+	result = LastForceRefresh(r)
+	expected = time.Unix(1700000000, 0)
+
+	if !result.Equal(expected) {
+		t.Errorf(`Unexpected context value, got %v instead of %v`, result, expected)
+	}
+}
+
+func TestWebAuthnSessionData(t *testing.T) {
+	r, _ := http.NewRequest("GET", "http://example.org", nil)
+
+	result := WebAuthnSessionData(r)
+	if result != nil {
+		t.Errorf("Unexpected context value, got %v instead of nil", result)
+	}
+
+	ctx := r.Context()
+	ctx = context.WithValue(ctx, WebAuthnDataContextKey, "invalid")
+	r = r.WithContext(ctx)
+
+	result = WebAuthnSessionData(r)
+	if result != nil {
+		t.Errorf("Unexpected context value, got %v instead of nil", result)
+	}
+
+	session := model.WebAuthnSession{}
+	ctx = r.Context()
+	ctx = context.WithValue(ctx, WebAuthnDataContextKey, session)
+	r = r.WithContext(ctx)
+
+	result = WebAuthnSessionData(r)
+	if result == nil {
+		t.Errorf("Unexpected context value, got nil instead of session")
+	}
+}
+
 func TestClientIP(t *testing.T) {
 	r, _ := http.NewRequest("GET", "http://example.org", nil)
 
@@ -411,3 +519,25 @@ func TestClientIP(t *testing.T) {
 		t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
 	}
 }
+
+func TestGoogleReaderToken(t *testing.T) {
+	r, _ := http.NewRequest("GET", "http://example.org", nil)
+
+	result := GoogleReaderToken(r)
+	expected := ""
+
+	if result != expected {
+		t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
+	}
+
+	ctx := r.Context()
+	ctx = context.WithValue(ctx, GoogleReaderTokenKey, "token")
+	r = r.WithContext(ctx)
+
+	result = GoogleReaderToken(r)
+	expected = "token"
+
+	if result != expected {
+		t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
+	}
+}

+ 1 - 1
internal/http/request/cookie.go

@@ -5,7 +5,7 @@ package request // import "miniflux.app/v2/internal/http/request"
 
 import "net/http"
 
-// CookieValue returns the cookie value.
+// CookieValue returns the named cookie value, or an empty string if the cookie is missing.
 func CookieValue(r *http.Request, name string) string {
 	cookie, err := r.Cookie(name)
 	if err != nil {

+ 9 - 9
internal/http/request/params.go

@@ -11,7 +11,7 @@ import (
 	"github.com/gorilla/mux"
 )
 
-// FormInt64Value returns a form value as integer.
+// FormInt64Value returns the named form value parsed as int64, or 0 on error.
 func FormInt64Value(r *http.Request, param string) int64 {
 	value := r.FormValue(param)
 	integer, err := strconv.ParseInt(value, 10, 64)
@@ -22,7 +22,7 @@ func FormInt64Value(r *http.Request, param string) int64 {
 	return integer
 }
 
-// RouteInt64Param returns an URL route parameter as int64.
+// RouteInt64Param returns the named route parameter parsed as int64, or 0 when missing or invalid.
 func RouteInt64Param(r *http.Request, param string) int64 {
 	vars := mux.Vars(r)
 	value, err := strconv.ParseInt(vars[param], 10, 64)
@@ -37,13 +37,13 @@ func RouteInt64Param(r *http.Request, param string) int64 {
 	return value
 }
 
-// RouteStringParam returns a URL route parameter as string.
+// RouteStringParam returns the named route parameter as a string.
 func RouteStringParam(r *http.Request, param string) string {
 	vars := mux.Vars(r)
 	return vars[param]
 }
 
-// QueryStringParam returns a query string parameter as string.
+// QueryStringParam returns the named query parameter, or defaultValue if it is empty.
 func QueryStringParam(r *http.Request, param, defaultValue string) string {
 	value := r.URL.Query().Get(param)
 	if value == "" {
@@ -52,7 +52,7 @@ func QueryStringParam(r *http.Request, param, defaultValue string) string {
 	return value
 }
 
-// QueryStringParamList returns all values associated to the parameter.
+// QueryStringParamList returns the non-empty, trimmed values for the named query parameter.
 func QueryStringParamList(r *http.Request, param string) []string {
 	var results []string
 	values := r.URL.Query()
@@ -69,7 +69,7 @@ func QueryStringParamList(r *http.Request, param string) []string {
 	return results
 }
 
-// QueryIntParam returns a query string parameter as integer.
+// QueryIntParam returns the named query parameter parsed as int, or defaultValue when missing, invalid, or negative.
 func QueryIntParam(r *http.Request, param string, defaultValue int) int {
 	value := r.URL.Query().Get(param)
 	if value == "" {
@@ -88,7 +88,7 @@ func QueryIntParam(r *http.Request, param string, defaultValue int) int {
 	return int(val)
 }
 
-// QueryInt64Param returns a query string parameter as int64.
+// QueryInt64Param returns the named query parameter parsed as int64, or defaultValue when missing, invalid, or negative.
 func QueryInt64Param(r *http.Request, param string, defaultValue int64) int64 {
 	value := r.URL.Query().Get(param)
 	if value == "" {
@@ -107,7 +107,7 @@ func QueryInt64Param(r *http.Request, param string, defaultValue int64) int64 {
 	return val
 }
 
-// QueryBoolParam returns a query string parameter as bool.
+// QueryBoolParam returns the named query parameter parsed as bool, or defaultValue when missing or invalid.
 func QueryBoolParam(r *http.Request, param string, defaultValue bool) bool {
 	value := r.URL.Query().Get(param)
 	if value == "" {
@@ -123,7 +123,7 @@ func QueryBoolParam(r *http.Request, param string, defaultValue bool) bool {
 	return val
 }
 
-// HasQueryParam checks if the query string contains the given parameter.
+// HasQueryParam reports whether the query string contains the named parameter.
 func HasQueryParam(r *http.Request, param string) bool {
 	values := r.URL.Query()
 	_, ok := values[param]

+ 54 - 1
internal/http/request/params_test.go

@@ -7,6 +7,7 @@ import (
 	"net/http"
 	"net/http/httptest"
 	"net/url"
+	"reflect"
 	"testing"
 
 	"github.com/gorilla/mux"
@@ -179,7 +180,7 @@ func TestQueryInt64Param(t *testing.T) {
 		t.Errorf(`Unexpected result, got %d instead of %d`, result, expected)
 	}
 
-	result = QueryInt64Param(r, "invalid", int64(69))
+	result = QueryInt64Param(r, "negative", int64(69))
 	expected = int64(69)
 
 	if result != expected {
@@ -194,6 +195,58 @@ func TestQueryInt64Param(t *testing.T) {
 	}
 }
 
+func TestQueryBoolParam(t *testing.T) {
+	u, _ := url.Parse("http://example.org/?truthy=true&falsy=false&invalid=wat")
+	r := &http.Request{URL: u}
+
+	result := QueryBoolParam(r, "truthy", false)
+	expected := true
+
+	if result != expected {
+		t.Errorf(`Unexpected result, got %v instead of %v`, result, expected)
+	}
+
+	result = QueryBoolParam(r, "falsy", true)
+	expected = false
+
+	if result != expected {
+		t.Errorf(`Unexpected result, got %v instead of %v`, result, expected)
+	}
+
+	result = QueryBoolParam(r, "missing", true)
+	expected = true
+
+	if result != expected {
+		t.Errorf(`Unexpected result, got %v instead of %v`, result, expected)
+	}
+
+	result = QueryBoolParam(r, "invalid", true)
+	expected = true
+
+	if result != expected {
+		t.Errorf(`Unexpected result, got %v instead of %v`, result, expected)
+	}
+}
+
+func TestQueryStringParamList(t *testing.T) {
+	u, _ := url.Parse("http://example.org/?tag=alpha&tag=beta&tag=+&tag=%20%20gamma%20%20&empty=")
+	r := &http.Request{URL: u}
+
+	result := QueryStringParamList(r, "tag")
+	expected := []string{"alpha", "beta", "gamma"}
+
+	if !reflect.DeepEqual(result, expected) {
+		t.Errorf(`Unexpected result, got %v instead of %v`, result, expected)
+	}
+
+	result = QueryStringParamList(r, "missing")
+	expected = nil
+
+	if !reflect.DeepEqual(result, expected) {
+		t.Errorf(`Unexpected result, got %v instead of %v`, result, expected)
+	}
+}
+
 func TestHasQueryParam(t *testing.T) {
 	u, _ := url.Parse("http://example.org/?key=42")
 	r := &http.Request{URL: u}