Procházet zdrojové kódy

Add base URL validation

Frédéric Guillot před 8 roky
rodič
revize
55a1e97778
2 změnil soubory, kde provedl 75 přidání a 17 odebrání
  1. 38 16
      config/config.go
  2. 37 1
      config/config_test.go

+ 38 - 16
config/config.go

@@ -8,6 +8,7 @@ import (
 	"net/url"
 	"os"
 	"strconv"
+	"strings"
 
 	"github.com/miniflux/miniflux/logger"
 )
@@ -54,6 +55,35 @@ func (c *Config) getInt(key string, fallback int) int {
 	return v
 }
 
+func (c *Config) parseBaseURL() {
+	baseURL := os.Getenv("BASE_URL")
+	if baseURL == "" {
+		return
+	}
+
+	if baseURL[len(baseURL)-1:] == "/" {
+		baseURL = baseURL[:len(baseURL)-1]
+	}
+
+	u, err := url.Parse(baseURL)
+	if err != nil {
+		logger.Error("Invalid BASE_URL: %v", err)
+		return
+	}
+
+	scheme := strings.ToLower(u.Scheme)
+	if scheme != "https" && scheme != "http" {
+		logger.Error("Invalid BASE_URL: scheme must be http or https")
+		return
+	}
+
+	c.baseURL = baseURL
+	c.basePath = u.Path
+
+	u.Path = ""
+	c.rootURL = u.String()
+}
+
 // HasDebugMode returns true if debug mode is enabled.
 func (c *Config) HasDebugMode() bool {
 	return c.get("DEBUG", "") != ""
@@ -61,31 +91,16 @@ func (c *Config) HasDebugMode() bool {
 
 // BaseURL returns the application base URL with path.
 func (c *Config) BaseURL() string {
-	if c.baseURL == "" {
-		c.baseURL = c.get("BASE_URL", defaultBaseURL)
-		if c.baseURL[len(c.baseURL)-1:] == "/" {
-			c.baseURL = c.baseURL[:len(c.baseURL)-1]
-		}
-	}
 	return c.baseURL
 }
 
 // RootURL returns the base URL without path.
 func (c *Config) RootURL() string {
-	if c.rootURL == "" {
-		u, _ := url.Parse(c.BaseURL())
-		u.Path = ""
-		c.rootURL = u.String()
-	}
 	return c.rootURL
 }
 
 // BasePath returns the application base path according to the base URL.
 func (c *Config) BasePath() string {
-	if c.basePath == "" {
-		u, _ := url.Parse(c.BaseURL())
-		c.basePath = u.Path
-	}
 	return c.basePath
 }
 
@@ -204,5 +219,12 @@ func (c *Config) PocketConsumerKey(defaultValue string) string {
 
 // NewConfig returns a new Config.
 func NewConfig() *Config {
-	return &Config{IsHTTPS: os.Getenv("HTTPS") != ""}
+	cfg := &Config{
+		baseURL: defaultBaseURL,
+		rootURL: defaultBaseURL,
+		IsHTTPS: os.Getenv("HTTPS") != "",
+	}
+
+	cfg.parseBaseURL()
+	return cfg
 }

+ 37 - 1
config/config_test.go

@@ -56,7 +56,7 @@ func TestCustomBaseURLWithTrailingSlash(t *testing.T) {
 	}
 
 	if cfg.RootURL() != "http://example.org" {
-		t.Fatalf(`Unexpected root URL, got "%s"`, cfg.BaseURL())
+		t.Fatalf(`Unexpected root URL, got "%s"`, cfg.RootURL())
 	}
 
 	if cfg.BasePath() != "/folder" {
@@ -64,6 +64,42 @@ func TestCustomBaseURLWithTrailingSlash(t *testing.T) {
 	}
 }
 
+func TestBaseURLWithoutScheme(t *testing.T) {
+	os.Clearenv()
+	os.Setenv("BASE_URL", "example.org/folder/")
+	cfg := NewConfig()
+
+	if cfg.BaseURL() != "http://localhost" {
+		t.Fatalf(`Unexpected base URL, got "%s"`, cfg.BaseURL())
+	}
+
+	if cfg.RootURL() != "http://localhost" {
+		t.Fatalf(`Unexpected root URL, got "%s"`, cfg.RootURL())
+	}
+
+	if cfg.BasePath() != "" {
+		t.Fatalf(`Unexpected base path, got "%s"`, cfg.BasePath())
+	}
+}
+
+func TestBaseURLWithInvalidScheme(t *testing.T) {
+	os.Clearenv()
+	os.Setenv("BASE_URL", "ftp://example.org/folder/")
+	cfg := NewConfig()
+
+	if cfg.BaseURL() != "http://localhost" {
+		t.Fatalf(`Unexpected base URL, got "%s"`, cfg.BaseURL())
+	}
+
+	if cfg.RootURL() != "http://localhost" {
+		t.Fatalf(`Unexpected root URL, got "%s"`, cfg.RootURL())
+	}
+
+	if cfg.BasePath() != "" {
+		t.Fatalf(`Unexpected base path, got "%s"`, cfg.BasePath())
+	}
+}
+
 func TestDefaultBaseURL(t *testing.T) {
 	os.Clearenv()
 	cfg := NewConfig()