Bläddra i källkod

fix(request): change FindRemoteIP to fallback to 127.0.0.1

Frédéric Guillot 1 månad sedan
förälder
incheckning
46e77eccb0

+ 25 - 1
internal/database/migrations.go

@@ -435,7 +435,7 @@ var migrations = [...]func(tx *sql.Tx) error{
 
 		hasExtra := false
 		if err := tx.QueryRow(`
-			SELECT true 
+			SELECT true
 			FROM information_schema.columns
 			WHERE
 				table_name='users' AND
@@ -1403,4 +1403,28 @@ var migrations = [...]func(tx *sql.Tx) error{
 		_, err = tx.Exec(sql)
 		return err
 	},
+	func(tx *sql.Tx) (err error) {
+		_, err = tx.Exec(`UPDATE user_sessions SET ip = '127.0.0.1'::inet WHERE ip IS NULL`)
+		if err != nil {
+			return err
+		}
+		_, err = tx.Exec(`UPDATE user_sessions SET created_at = now() WHERE created_at IS NULL`)
+		if err != nil {
+			return err
+		}
+		_, err = tx.Exec(`UPDATE user_sessions SET user_agent = '' WHERE user_agent IS NULL`)
+		if err != nil {
+			return err
+		}
+		_, err = tx.Exec(`
+			ALTER TABLE user_sessions
+				ALTER COLUMN ip SET DEFAULT '127.0.0.1'::inet,
+				ALTER COLUMN ip SET NOT NULL,
+				ALTER COLUMN created_at SET DEFAULT now(),
+				ALTER COLUMN created_at SET NOT NULL,
+				ALTER COLUMN user_agent SET DEFAULT '',
+				ALTER COLUMN user_agent SET NOT NULL
+		`)
+		return err
+	},
 }

+ 25 - 9
internal/http/request/client_ip.go

@@ -11,8 +11,8 @@ import (
 
 // IsTrustedIP reports whether the given remote IP address belongs to one of the trusted networks.
 func IsTrustedIP(remoteIP string, trustedNetworks []string) bool {
-	if remoteIP == "@" || strings.HasPrefix(remoteIP, "/") {
-		return true
+	if len(trustedNetworks) == 0 {
+		return false
 	}
 
 	ip := net.ParseIP(remoteIP)
@@ -57,19 +57,35 @@ func FindClientIP(r *http.Request, isTrustedProxyClient bool) string {
 	return FindRemoteIP(r)
 }
 
-// FindRemoteIP returns the remote client IP address without considering HTTP headers.
+// FindRemoteIP returns the parsed remote IP address from the request,
+// falling back to 127.0.0.1 if the address is empty, a unix socket, or invalid.
 func FindRemoteIP(r *http.Request) string {
-	remoteIP, _, err := net.SplitHostPort(r.RemoteAddr)
+	if r.RemoteAddr == "@" || r.RemoteAddr == "" {
+		return "127.0.0.1"
+	}
+
+	// If it looks like it has a port (IPv4:port or [IPv6]:port), try to split it.
+	ip, _, err := net.SplitHostPort(r.RemoteAddr)
 	if err != nil {
-		remoteIP = r.RemoteAddr
+		// No port — could be a bare IPv4, IPv6, or IPv6 with zone.
+		ip = r.RemoteAddr
 	}
-	return dropIPv6zone(remoteIP)
+
+	// Strip IPv6 zone identifier if present (e.g., %eth0).
+	ip = dropIPv6zone(ip)
+
+	// Validate the IP address.
+	if net.ParseIP(ip) == nil {
+		return "127.0.0.1"
+	}
+
+	return ip
 }
 
 func dropIPv6zone(address string) string {
-	i := strings.IndexByte(address, '%')
-	if i != -1 {
-		address = address[:i]
+	idx := strings.IndexByte(address, '%')
+	if idx != -1 {
+		address = address[:idx]
 	}
 	return address
 }

+ 35 - 10
internal/http/request/client_ip_test.go

@@ -33,6 +33,16 @@ func TestFindClientIPWithoutHeaders(t *testing.T) {
 	if ip := FindClientIP(r, false); ip != "fe80::14c2:f039:edc7:edc7" {
 		t.Fatalf(`Unexpected result, got: %q`, ip)
 	}
+
+	r = &http.Request{RemoteAddr: "@"}
+	if ip := FindClientIP(r, false); ip != "127.0.0.1" {
+		t.Fatalf(`Unexpected result, got: %q`, ip)
+	}
+
+	r = &http.Request{RemoteAddr: ""}
+	if ip := FindClientIP(r, false); ip != "127.0.0.1" {
+		t.Fatalf(`Unexpected result, got: %q`, ip)
+	}
 }
 
 func TestFindClientIPWithXFFHeader(t *testing.T) {
@@ -104,14 +114,6 @@ func TestClientIPWithBothHeaders(t *testing.T) {
 	}
 }
 
-func TestClientIPWithUnixSocketRemoteAddress(t *testing.T) {
-	r := &http.Request{RemoteAddr: "@"}
-
-	if ip := FindClientIP(r, false); ip != "@" {
-		t.Fatalf(`Unexpected result, got: %q`, ip)
-	}
-}
-
 func TestClientIPWithUnixSocketRemoteAddrAndBothHeaders(t *testing.T) {
 	headers := http.Header{}
 	headers.Set("X-Forwarded-For", "203.0.113.195, 70.41.3.18, 150.172.238.178")
@@ -136,8 +138,9 @@ func TestIsTrustedIP(t *testing.T) {
 		{"::1", true},
 		{"192.168.1.1", false},
 		{"invalid", false},
-		{"@", true},
-		{"/tmp/miniflux.sock", true},
+		{"@", false},
+		{"/tmp/miniflux.sock", false},
+		{"", false},
 	}
 
 	for _, scenario := range scenarios {
@@ -155,3 +158,25 @@ func TestIsTrustedIP(t *testing.T) {
 		t.Error("Expected false when trusted networks list is empty")
 	}
 }
+
+func TestFindRemoteIP(t *testing.T) {
+	scenarios := []struct {
+		ip       string
+		expected string
+	}{
+		{"192.168.0.1:4242", "192.168.0.1"},
+		{"[2001:db8::1]:4242", "2001:db8::1"},
+		{"fe80::14c2:f039:edc7:edc7%eth0", "fe80::14c2:f039:edc7:edc7"},
+		{"", "127.0.0.1"},
+		{"@", "127.0.0.1"},
+		{"invalid", "127.0.0.1"},
+	}
+
+	for _, scenario := range scenarios {
+		r := &http.Request{RemoteAddr: scenario.ip}
+		result := FindRemoteIP(r)
+		if result != scenario.expected {
+			t.Errorf("Expected %q for RemoteAddr %q, got %q", scenario.expected, scenario.ip, result)
+		}
+	}
+}

+ 4 - 1
internal/storage/user_session.go

@@ -58,6 +58,9 @@ func (s *Storage) UserSessions(userID int64) ([]model.UserSession, error) {
 // CreateUserSessionFromUsername creates a new user session.
 func (s *Storage) CreateUserSessionFromUsername(username, userAgent, ip string) (sessionID string, userID int64, err error) {
 	token := rand.Text()
+	if ip == "" {
+		ip = "127.0.0.1"
+	}
 
 	tx, err := s.db.Begin()
 	if err != nil {
@@ -100,7 +103,7 @@ func (s *Storage) UserSessionByToken(token string) (*model.UserSession, error) {
 			token,
 			created_at,
 			user_agent,
-			ip 
+			ip
 		FROM
 			user_sessions
 		WHERE