Browse Source

feat: Allow multiple listen addresses

This change implements the ability to specify multiple listen addresses.
This allows the application to listen on different interfaces or ports simultaneously,
or a combination of IP addresses and Unix sockets.

Closes #3343
Ingmar Stein 10 months ago
parent
commit
8fa5041c37

+ 14 - 4
internal/cli/daemon.go

@@ -33,9 +33,9 @@ func startDaemon(store *storage.Storage) {
 		runScheduler(store, pool)
 	}
 
-	var httpServer *http.Server
+	var httpServers []*http.Server
 	if config.Opts.HasHTTPService() {
-		httpServer = httpd.StartWebServer(store, pool)
+		httpServers = httpd.StartWebServer(store, pool)
 	}
 
 	if config.Opts.HasMetricsCollector() {
@@ -78,8 +78,18 @@ func startDaemon(store *storage.Storage) {
 	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
 	defer cancel()
 
-	if httpServer != nil {
-		httpServer.Shutdown(ctx)
+	if len(httpServers) > 0 {
+		slog.Debug("Shutting down HTTP servers...")
+		for _, server := range httpServers {
+			if server != nil {
+				if err := server.Shutdown(ctx); err != nil {
+					slog.Error("HTTP server shutdown error", slog.Any("error", err), slog.String("addr", server.Addr))
+				}
+			}
+		}
+		slog.Debug("All HTTP servers shut down.")
+	} else {
+		slog.Debug("No HTTP servers to shut down.")
 	}
 
 	slog.Debug("Process gracefully stopped")

+ 1 - 1
internal/cli/health_check.go

@@ -14,7 +14,7 @@ import (
 
 func doHealthCheck(healthCheckEndpoint string) {
 	if healthCheckEndpoint == "auto" {
-		healthCheckEndpoint = "http://" + config.Opts.ListenAddr() + config.Opts.BasePath() + "/healthcheck"
+		healthCheckEndpoint = "http://" + config.Opts.ListenAddr()[0] + config.Opts.BasePath() + "/healthcheck"
 	}
 
 	slog.Debug("Executing health check request", slog.String("endpoint", healthCheckEndpoint))

+ 11 - 10
internal/config/config_test.go

@@ -6,6 +6,7 @@ package config // import "miniflux.app/v2/internal/config"
 import (
 	"bytes"
 	"os"
+	"reflect"
 	"testing"
 )
 
@@ -428,18 +429,18 @@ func TestListenAddr(t *testing.T) {
 		t.Fatalf(`Parsing failure: %v`, err)
 	}
 
-	expected := "foobar"
+	expected := []string{"foobar"}
 	result := opts.ListenAddr()
 
-	if result != expected {
-		t.Fatalf(`Unexpected LISTEN_ADDR value, got %q instead of %q`, result, expected)
+	if !reflect.DeepEqual(result, expected) {
+		t.Fatalf(`Unexpected LISTEN_ADDR value, got %v instead of %v`, result, expected)
 	}
 }
 
 func TestListenAddrWithPortDefined(t *testing.T) {
 	os.Clearenv()
 	os.Setenv("PORT", "3000")
-	os.Setenv("LISTEN_ADDR", "foobar")
+	os.Setenv("LISTEN_ADDR", "foobar") // This should be overridden by PORT
 
 	parser := NewParser()
 	opts, err := parser.ParseEnvironmentVariables()
@@ -447,11 +448,11 @@ func TestListenAddrWithPortDefined(t *testing.T) {
 		t.Fatalf(`Parsing failure: %v`, err)
 	}
 
-	expected := ":3000"
+	expected := []string{":3000"}
 	result := opts.ListenAddr()
 
-	if result != expected {
-		t.Fatalf(`Unexpected LISTEN_ADDR value, got %q instead of %q`, result, expected)
+	if !reflect.DeepEqual(result, expected) {
+		t.Fatalf(`Unexpected LISTEN_ADDR value when PORT is set, got %v instead of %v`, result, expected)
 	}
 }
 
@@ -464,11 +465,11 @@ func TestDefaultListenAddrValue(t *testing.T) {
 		t.Fatalf(`Parsing failure: %v`, err)
 	}
 
-	expected := defaultListenAddr
+	expected := []string{defaultListenAddr}
 	result := opts.ListenAddr()
 
-	if result != expected {
-		t.Fatalf(`Unexpected LISTEN_ADDR value, got %q instead of %q`, result, expected)
+	if !reflect.DeepEqual(result, expected) {
+		t.Fatalf(`Unexpected default LISTEN_ADDR value, got %v instead of %v`, result, expected)
 	}
 }
 

+ 4 - 4
internal/config/options.go

@@ -119,7 +119,7 @@ type Options struct {
 	databaseMinConns                   int
 	databaseConnectionLifetime         int
 	runMigrations                      bool
-	listenAddr                         string
+	listenAddr                         []string
 	certFile                           string
 	certDomain                         string
 	certKeyFile                        string
@@ -202,7 +202,7 @@ func NewOptions() *Options {
 		databaseMinConns:                   defaultDatabaseMinConns,
 		databaseConnectionLifetime:         defaultDatabaseConnectionLifetime,
 		runMigrations:                      defaultRunMigrations,
-		listenAddr:                         defaultListenAddr,
+		listenAddr:                         []string{defaultListenAddr},
 		certFile:                           defaultCertFile,
 		certDomain:                         defaultCertDomain,
 		certKeyFile:                        defaultKeyFile,
@@ -339,7 +339,7 @@ func (o *Options) DatabaseConnectionLifetime() time.Duration {
 }
 
 // ListenAddr returns the listen address for the HTTP server.
-func (o *Options) ListenAddr() string {
+func (o *Options) ListenAddr() []string {
 	return o.listenAddr
 }
 
@@ -740,7 +740,7 @@ func (o *Options) SortedOptions(redactSecret bool) []*Option {
 		"HTTP_SERVICE":                           o.httpService,
 		"INVIDIOUS_INSTANCE":                     o.invidiousInstance,
 		"KEY_FILE":                               o.certKeyFile,
-		"LISTEN_ADDR":                            o.listenAddr,
+		"LISTEN_ADDR":                            strings.Join(o.listenAddr, ","),
 		"LOG_FILE":                               o.logFile,
 		"LOG_DATE_TIME":                          o.logDateTime,
 		"LOG_FORMAT":                             o.logFormat,

+ 6 - 2
internal/config/parser.go

@@ -94,7 +94,7 @@ func (p *Parser) parseLines(lines []string) (err error) {
 		case "PORT":
 			port = value
 		case "LISTEN_ADDR":
-			p.opts.listenAddr = parseString(value, defaultListenAddr)
+			p.opts.listenAddr = parseStringList(value, []string{defaultListenAddr})
 		case "DATABASE_URL":
 			p.opts.databaseURL = parseString(value, defaultDatabaseURL)
 		case "DATABASE_URL_FILE":
@@ -258,7 +258,7 @@ func (p *Parser) parseLines(lines []string) (err error) {
 	}
 
 	if port != "" {
-		p.opts.listenAddr = ":" + port
+		p.opts.listenAddr = []string{":" + port}
 	}
 
 	youtubeEmbedURL, err := url.Parse(p.opts.youTubeEmbedUrlOverride)
@@ -339,6 +339,10 @@ func parseStringList(value string, fallback []string) []string {
 	for _, item := range items {
 		itemValue := strings.TrimSpace(item)
 
+		if itemValue == "" {
+			continue
+		}
+
 		if _, found := strMap[itemValue]; !found {
 			strMap[itemValue] = true
 			strList = append(strList, itemValue)

+ 106 - 0
internal/config/parser_test.go

@@ -4,6 +4,7 @@
 package config // import "miniflux.app/v2/internal/config"
 
 import (
+	"reflect"
 	"testing"
 )
 
@@ -58,3 +59,108 @@ func TestParseIntValue(t *testing.T) {
 		t.Errorf(`Defined variables should returns the specified value`)
 	}
 }
+
+func TestParseListenAddr(t *testing.T) {
+	defaultExpected := []string{defaultListenAddr}
+
+	tests := []struct {
+		name           string
+		listenAddr     string
+		port           string
+		expected       []string
+		lines          []string // Used for direct lines parsing instead of individual env vars
+		isLineOriented bool     // Flag to indicate if we use lines
+	}{
+		{
+			name:       "Single LISTEN_ADDR",
+			listenAddr: "127.0.0.1:8080",
+			expected:   []string{"127.0.0.1:8080"},
+		},
+		{
+			name:       "Multiple LISTEN_ADDR comma-separated",
+			listenAddr: "127.0.0.1:8080,:8081,/tmp/miniflux.sock",
+			expected:   []string{"127.0.0.1:8080", ":8081", "/tmp/miniflux.sock"},
+		},
+		{
+			name:       "Multiple LISTEN_ADDR with spaces around commas",
+			listenAddr: "127.0.0.1:8080 , :8081",
+			expected:   []string{"127.0.0.1:8080", ":8081"},
+		},
+		{
+			name:       "Empty LISTEN_ADDR",
+			listenAddr: "",
+			expected:   defaultExpected,
+		},
+		{
+			name:       "PORT overrides LISTEN_ADDR",
+			listenAddr: "127.0.0.1:8000",
+			port:       "8082",
+			expected:   []string{":8082"},
+		},
+		{
+			name:       "PORT overrides empty LISTEN_ADDR",
+			listenAddr: "",
+			port:       "8083",
+			expected:   []string{":8083"},
+		},
+		{
+			name:       "LISTEN_ADDR with empty segment (comma)",
+			listenAddr: "127.0.0.1:8080,,:8081",
+			expected:   []string{"127.0.0.1:8080", ":8081"},
+		},
+		{
+			name:           "PORT override with lines parsing",
+			isLineOriented: true,
+			lines:          []string{"LISTEN_ADDR=127.0.0.1:8000", "PORT=8082"},
+			expected:       []string{":8082"},
+		},
+		{
+			name:           "LISTEN_ADDR only with lines parsing (comma)",
+			isLineOriented: true,
+			lines:          []string{"LISTEN_ADDR=10.0.0.1:9090,10.0.0.2:9091"},
+			expected:       []string{"10.0.0.1:9090", "10.0.0.2:9091"},
+		},
+		{
+			name:           "Empty LISTEN_ADDR with lines parsing (default)",
+			isLineOriented: true,
+			lines:          []string{"LISTEN_ADDR="},
+			expected:       defaultExpected,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			parser := NewParser()
+			var err error
+
+			if tt.isLineOriented {
+				err = parser.parseLines(tt.lines)
+			} else {
+				// Simulate os.Environ() behaviour for individual var testing
+				var envLines []string
+				if tt.listenAddr != "" {
+					envLines = append(envLines, "LISTEN_ADDR="+tt.listenAddr)
+				}
+				if tt.port != "" {
+					envLines = append(envLines, "PORT="+tt.port)
+				}
+				// Add a dummy var if both are empty to avoid empty lines slice if not intended
+				if tt.listenAddr == "" && tt.port == "" && tt.name == "Empty LISTEN_ADDR" {
+					// This case specifically tests empty LISTEN_ADDR resulting in default
+					// So, we pass LISTEN_ADDR=
+					envLines = append(envLines, "LISTEN_ADDR=")
+				}
+				err = parser.parseLines(envLines)
+			}
+
+			if err != nil {
+				t.Fatalf("parseLines() error = %v", err)
+			}
+
+			opts := parser.opts
+			if !reflect.DeepEqual(opts.ListenAddr(), tt.expected) {
+				t.Errorf("ListenAddr() got = %v, want %v", opts.ListenAddr(), tt.expected)
+			}
+		})
+	}
+}

+ 93 - 58
internal/http/server/httpd.go

@@ -4,6 +4,7 @@
 package httpd // import "miniflux.app/v2/internal/http/server"
 
 import (
+	"crypto/tls"
 	"fmt"
 	"log/slog"
 	"net"
@@ -29,36 +30,84 @@ import (
 	"golang.org/x/crypto/acme/autocert"
 )
 
-func StartWebServer(store *storage.Storage, pool *worker.Pool) *http.Server {
+func StartWebServer(store *storage.Storage, pool *worker.Pool) []*http.Server {
+	listenAddresses := config.Opts.ListenAddr()
+	var httpServers []*http.Server
+
 	certFile := config.Opts.CertFile()
 	keyFile := config.Opts.CertKeyFile()
 	certDomain := config.Opts.CertDomain()
-	listenAddr := config.Opts.ListenAddr()
-	server := &http.Server{
-		ReadTimeout:  time.Duration(config.Opts.HTTPServerTimeout()) * time.Second,
-		WriteTimeout: time.Duration(config.Opts.HTTPServerTimeout()) * time.Second,
-		IdleTimeout:  time.Duration(config.Opts.HTTPServerTimeout()) * time.Second,
-		Handler:      setupHandler(store, pool),
-	}
+	var sharedAutocertTLSConfig *tls.Config
+
+	if certDomain != "" {
+		slog.Debug("Configuring autocert manager and shared TLS config", slog.String("domain", certDomain))
+		certManager := autocert.Manager{
+			Cache:      storage.NewCertificateCache(store),
+			Prompt:     autocert.AcceptTOS,
+			HostPolicy: autocert.HostWhitelist(certDomain),
+		}
 
-	switch {
-	case os.Getenv("LISTEN_PID") == strconv.Itoa(os.Getpid()):
-		startSystemdSocketServer(server)
-	case strings.HasPrefix(listenAddr, "/"):
-		startUnixSocketServer(server, listenAddr)
-	case certDomain != "":
-		config.Opts.HTTPS = true
-		startAutoCertTLSServer(server, certDomain, store)
-	case certFile != "" && keyFile != "":
+		sharedAutocertTLSConfig = &tls.Config{}
+		sharedAutocertTLSConfig.GetCertificate = certManager.GetCertificate
+		sharedAutocertTLSConfig.NextProtos = []string{"h2", "http/1.1", acme.ALPNProto}
+
+		challengeServer := &http.Server{
+			Handler: certManager.HTTPHandler(nil),
+			Addr:    ":http",
+		}
+		slog.Info("Starting ACME HTTP challenge server for autocert", slog.String("address", challengeServer.Addr))
+		go func() {
+			if err := challengeServer.ListenAndServe(); err != http.ErrServerClosed {
+				slog.Error("ACME HTTP challenge server failed", slog.Any("error", err))
+			}
+		}()
 		config.Opts.HTTPS = true
-		server.Addr = listenAddr
-		startTLSServer(server, certFile, keyFile)
-	default:
-		server.Addr = listenAddr
-		startHTTPServer(server)
+		httpServers = append(httpServers, challengeServer)
 	}
 
-	return server
+	for i, listenAddr := range listenAddresses {
+		server := &http.Server{
+			ReadTimeout:  time.Duration(config.Opts.HTTPServerTimeout()) * time.Second,
+			WriteTimeout: time.Duration(config.Opts.HTTPServerTimeout()) * time.Second,
+			IdleTimeout:  time.Duration(config.Opts.HTTPServerTimeout()) * time.Second,
+			Handler:      setupHandler(store, pool),
+		}
+
+		if !strings.HasPrefix(listenAddr, "/") && os.Getenv("LISTEN_PID") != strconv.Itoa(os.Getpid()) {
+			server.Addr = listenAddr
+		}
+
+		shouldAddServer := true
+
+		switch {
+		case os.Getenv("LISTEN_PID") == strconv.Itoa(os.Getpid()):
+			if i == 0 {
+				slog.Info("Starting server using systemd socket for the first listen address", slog.String("address_info", listenAddr))
+				startSystemdSocketServer(server)
+			} else {
+				slog.Warn("Systemd socket activation: Only the first listen address is used by systemd. Other addresses ignored.", slog.String("skipped_address", listenAddr))
+				shouldAddServer = false
+			}
+		case strings.HasPrefix(listenAddr, "/"): // Unix socket
+			startUnixSocketServer(server, listenAddr)
+		case certDomain != "" && (listenAddr == ":https" || (i == 0 && strings.Contains(listenAddr, ":"))):
+			server.Addr = listenAddr
+			startAutoCertTLSServer(server, sharedAutocertTLSConfig)
+		case certFile != "" && keyFile != "":
+			server.Addr = listenAddr
+			startTLSServer(server, certFile, keyFile)
+			config.Opts.HTTPS = true
+		default:
+			server.Addr = listenAddr
+			startHTTPServer(server)
+		}
+
+		if shouldAddServer {
+			httpServers = append(httpServers, server)
+		}
+	}
+
+	return httpServers
 }
 
 func startSystemdSocketServer(server *http.Server) {
@@ -71,56 +120,42 @@ func startSystemdSocketServer(server *http.Server) {
 
 		slog.Info(`Starting server using systemd socket`)
 		if err := server.Serve(listener); err != http.ErrServerClosed {
-			printErrorAndExit(`Server failed to start: %v`, err)
+			printErrorAndExit(`Systemd socket server failed to start: %v`, err)
 		}
 	}()
 }
 
 func startUnixSocketServer(server *http.Server, socketFile string) {
-	os.Remove(socketFile)
-
-	go func(sock string) {
-		listener, err := net.Listen("unix", sock)
-		if err != nil {
-			printErrorAndExit(`Server failed to start: %v`, err)
-		}
-		defer listener.Close()
+	if err := os.Remove(socketFile); err != nil && !os.IsNotExist(err) {
+		printErrorAndExit("Unable to remove existing Unix socket %s: %v", socketFile, err)
+	}
+	listener, err := net.Listen("unix", socketFile)
+	if err != nil {
+		printErrorAndExit(`Server failed to listen on Unix socket %s: %v`, socketFile, err)
+	}
 
-		if err := os.Chmod(sock, 0666); err != nil {
-			printErrorAndExit(`Unable to change socket permission: %v`, err)
-		}
+	if err := os.Chmod(socketFile, 0666); err != nil {
+		printErrorAndExit(`Unable to change socket permission for %s: %v`, socketFile, err)
+	}
 
-		slog.Info("Starting server using a Unix socket", slog.String("socket", sock))
+	go func() {
+		slog.Info("Starting server using a Unix socket", slog.String("socket", socketFile))
 		if err := server.Serve(listener); err != http.ErrServerClosed {
-			printErrorAndExit(`Server failed to start: %v`, err)
+			printErrorAndExit(fmt.Sprintf("Unix socket server failed to start on %s: %%v", socketFile), err)
 		}
-	}(socketFile)
+	}()
 }
 
-func startAutoCertTLSServer(server *http.Server, certDomain string, store *storage.Storage) {
-	server.Addr = ":https"
-	certManager := autocert.Manager{
-		Cache:      storage.NewCertificateCache(store),
-		Prompt:     autocert.AcceptTOS,
-		HostPolicy: autocert.HostWhitelist(certDomain),
-	}
-	server.TLSConfig.GetCertificate = certManager.GetCertificate
-	server.TLSConfig.NextProtos = []string{"h2", "http/1.1", acme.ALPNProto}
-
-	// Handle http-01 challenge.
-	s := &http.Server{
-		Handler: certManager.HTTPHandler(nil),
-		Addr:    ":http",
-	}
-	go s.ListenAndServe()
+func startAutoCertTLSServer(server *http.Server, autoTLSConfig *tls.Config) {
+	server.TLSConfig.GetCertificate = autoTLSConfig.GetCertificate
+	server.TLSConfig.NextProtos = autoTLSConfig.NextProtos
 
 	go func() {
 		slog.Info("Starting TLS server using automatic certificate management",
 			slog.String("listen_address", server.Addr),
-			slog.String("domain", certDomain),
 		)
 		if err := server.ListenAndServeTLS("", ""); err != http.ErrServerClosed {
-			printErrorAndExit(`Server failed to start: %v`, err)
+			printErrorAndExit(fmt.Sprintf("Autocert server failed to start on %s: %%v", server.Addr), err)
 		}
 	}()
 }
@@ -133,7 +168,7 @@ func startTLSServer(server *http.Server, certFile, keyFile string) {
 			slog.String("key_file", keyFile),
 		)
 		if err := server.ListenAndServeTLS(certFile, keyFile); err != http.ErrServerClosed {
-			printErrorAndExit(`Server failed to start: %v`, err)
+			printErrorAndExit(fmt.Sprintf("TLS server failed to start on %s: %%v", server.Addr), err)
 		}
 	}()
 }
@@ -144,7 +179,7 @@ func startHTTPServer(server *http.Server) {
 			slog.String("listen_address", server.Addr),
 		)
 		if err := server.ListenAndServe(); err != http.ErrServerClosed {
-			printErrorAndExit(`Server failed to start: %v`, err)
+			printErrorAndExit(fmt.Sprintf("HTTP server failed to start on %s: %%v", server.Addr), err)
 		}
 	}()
 }