client_test.go 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113
  1. // SPDX-FileCopyrightText: Copyright The Miniflux Authors. All rights reserved.
  2. // SPDX-License-Identifier: Apache-2.0
  3. package client
  4. import (
  5. "errors"
  6. "net"
  7. "net/http"
  8. "net/http/httptest"
  9. "testing"
  10. "time"
  11. )
  12. func TestNewClientWithoutBlockingPrivateNetworks(t *testing.T) {
  13. server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  14. w.WriteHeader(http.StatusOK)
  15. }))
  16. defer server.Close()
  17. client := NewClientWithOptions(Options{Timeout: 5 * time.Second})
  18. resp, err := client.Get(server.URL)
  19. if err != nil {
  20. t.Fatalf("Expected no error, got %v", err)
  21. }
  22. defer resp.Body.Close()
  23. if resp.StatusCode != http.StatusOK {
  24. t.Fatalf("Expected status 200, got %d", resp.StatusCode)
  25. }
  26. }
  27. func TestBlockPrivateNetworksBlocksLoopback(t *testing.T) {
  28. server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  29. w.WriteHeader(http.StatusOK)
  30. }))
  31. defer server.Close()
  32. client := NewClientWithOptions(Options{Timeout: 5 * time.Second, BlockPrivateNetworks: true})
  33. _, err := client.Get(server.URL)
  34. if err == nil {
  35. t.Fatal("Expected an error when connecting to loopback address, got nil")
  36. }
  37. if !errors.Is(err, ErrPrivateNetwork) {
  38. t.Fatalf("Expected ErrPrivateNetwork, got %v", err)
  39. }
  40. }
  41. func TestBlockPrivateNetworksAllowsPublicIPs(t *testing.T) {
  42. client := NewClientWithOptions(Options{Timeout: 5 * time.Second, BlockPrivateNetworks: true})
  43. if client == nil {
  44. t.Fatal("Expected non-nil client")
  45. }
  46. transport, ok := client.Transport.(*http.Transport)
  47. if !ok {
  48. t.Fatal("Expected custom http.Transport when blockPrivateNetworks is true")
  49. }
  50. if transport.DialContext == nil {
  51. t.Fatal("Expected custom DialContext when blockPrivateNetworks is true")
  52. }
  53. }
  54. func TestNoCustomTransportWhenNotBlocking(t *testing.T) {
  55. client := NewClientWithOptions(Options{Timeout: 5 * time.Second})
  56. if client.Transport != nil {
  57. t.Fatal("Expected nil transport when blockPrivateNetworks is false")
  58. }
  59. }
  60. func TestBlockPrivateNetworksBlocksPrivateIP(t *testing.T) {
  61. listener, err := net.Listen("tcp", "127.0.0.1:0")
  62. if err != nil {
  63. t.Fatalf("Failed to create listener: %v", err)
  64. }
  65. defer listener.Close()
  66. server := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  67. w.WriteHeader(http.StatusOK)
  68. }))
  69. server.Listener = listener
  70. server.Start()
  71. defer server.Close()
  72. client := NewClientWithOptions(Options{Timeout: 5 * time.Second, BlockPrivateNetworks: true})
  73. _, err = client.Get(server.URL)
  74. if err == nil {
  75. t.Fatal("Expected error when connecting to private IP")
  76. }
  77. if !errors.Is(err, ErrPrivateNetwork) {
  78. t.Fatalf("Expected ErrPrivateNetwork, got: %v", err)
  79. }
  80. }
  81. func TestBlockPrivateNetworksAllowsLoopbackWhenDisabled(t *testing.T) {
  82. server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  83. w.WriteHeader(http.StatusOK)
  84. }))
  85. defer server.Close()
  86. client := NewClientWithOptions(Options{Timeout: 5 * time.Second})
  87. resp, err := client.Get(server.URL)
  88. if err != nil {
  89. t.Fatalf("Expected no error when blockPrivateNetworks is false, got %v", err)
  90. }
  91. defer resp.Body.Close()
  92. if resp.StatusCode != http.StatusOK {
  93. t.Fatalf("Expected status 200, got %d", resp.StatusCode)
  94. }
  95. }