Просмотр исходного кода

Fixes IPv6 detection from headers (#14456)

* fixes client ip detection for v6

* adds test for get_client_ip

* Employ urlparse() to strip port numbers from IPs

---------

Co-authored-by: Jeremy Stretch <jstretch@netboxlabs.com>
Abhimanyu Saharan 2 лет назад
Родитель
Сommit
92bdaa2120
2 измененных файлов с 41 добавлено и 5 удалено
  1. 13 5
      netbox/utilities/request.py
  2. 28 0
      netbox/utilities/tests/test_request.py

+ 13 - 5
netbox/utilities/request.py

@@ -1,4 +1,5 @@
-from netaddr import IPAddress
+from netaddr import AddrFormatError, IPAddress
+from urllib.parse import urlparse
 
 __all__ = (
     'get_client_ip',
@@ -17,11 +18,18 @@ def get_client_ip(request, additional_headers=()):
     )
     for header in HTTP_HEADERS:
         if header in request.META:
-            client_ip = request.META[header].split(',')[0].partition(':')[0]
+            ip = request.META[header].split(',')[0].strip()
             try:
-                return IPAddress(client_ip)
-            except ValueError:
-                raise ValueError(f"Invalid IP address set for {header}: {client_ip}")
+                return IPAddress(ip)
+            except AddrFormatError:
+                # Parse the string with urlparse() to remove port number or any other cruft
+                ip = urlparse(f'//{ip}').hostname
+
+            try:
+                return IPAddress(ip)
+            except AddrFormatError:
+                # We did our best
+                raise ValueError(f"Invalid IP address set for {header}: {ip}")
 
     # Could not determine the client IP address from request headers
     return None

+ 28 - 0
netbox/utilities/tests/test_request.py

@@ -0,0 +1,28 @@
+from django.test import TestCase, RequestFactory
+
+from netaddr import IPAddress
+from utilities.request import get_client_ip
+
+
+class GetClientIPTests(TestCase):
+    def setUp(self):
+        self.factory = RequestFactory()
+
+    def test_ipv4_address(self):
+        request = self.factory.get('/', HTTP_X_FORWARDED_FOR='192.168.1.1')
+        self.assertEqual(get_client_ip(request), IPAddress('192.168.1.1'))
+        request = self.factory.get('/', HTTP_X_FORWARDED_FOR='192.168.1.1:8080')
+        self.assertEqual(get_client_ip(request), IPAddress('192.168.1.1'))
+
+    def test_ipv6_address(self):
+        request = self.factory.get('/', HTTP_X_FORWARDED_FOR='2001:db8::8a2e:370:7334')
+        self.assertEqual(get_client_ip(request), IPAddress('2001:db8::8a2e:370:7334'))
+        request = self.factory.get('/', HTTP_X_FORWARDED_FOR='[2001:db8::8a2e:370:7334]')
+        self.assertEqual(get_client_ip(request), IPAddress('2001:db8::8a2e:370:7334'))
+        request = self.factory.get('/', HTTP_X_FORWARDED_FOR='[2001:db8::8a2e:370:7334]:8080')
+        self.assertEqual(get_client_ip(request), IPAddress('2001:db8::8a2e:370:7334'))
+
+    def test_invalid_ip_address(self):
+        request = self.factory.get('/', HTTP_X_FORWARDED_FOR='invalid_ip')
+        with self.assertRaises(ValueError):
+            get_client_ip(request)