Browse Source

Merge pull request #3847 from kobayashi/3525

Fixes #3525: Filter muiltiple ipaddress terms
Jeremy Stretch 6 years ago
parent
commit
d96f474a5f

+ 4 - 0
docs/release-notes/version-2.6.md

@@ -1,5 +1,9 @@
 # v2.6.13 (FUTURE)
 # v2.6.13 (FUTURE)
 
 
+## Enhancements
+
+* [#3525](https://github.com/netbox-community/netbox/issues/3525) - Enable IP address filtering with multiple address terms
+
 ## Bug Fixes
 ## Bug Fixes
 
 
 * [#3914](https://github.com/netbox-community/netbox/issues/3914) - Fix interface filter field when unauthenticated
 * [#3914](https://github.com/netbox-community/netbox/issues/3914) - Fix interface filter field when unauthenticated

+ 8 - 2
netbox/ipam/fields.py

@@ -1,6 +1,6 @@
 from django.core.exceptions import ValidationError
 from django.core.exceptions import ValidationError
 from django.db import models
 from django.db import models
-from netaddr import AddrFormatError, IPNetwork
+from netaddr import AddrFormatError, IPNetwork, IPAddress
 
 
 from . import lookups
 from . import lookups
 from .formfields import IPFormField
 from .formfields import IPFormField
@@ -23,7 +23,10 @@ class BaseIPField(models.Field):
         if not value:
         if not value:
             return value
             return value
         try:
         try:
-            return IPNetwork(value)
+            if '/' in str(value):
+                return IPNetwork(value)
+            else:
+                return IPAddress(value)
         except AddrFormatError as e:
         except AddrFormatError as e:
             raise ValidationError("Invalid IP address format: {}".format(value))
             raise ValidationError("Invalid IP address format: {}".format(value))
         except (TypeError, ValueError) as e:
         except (TypeError, ValueError) as e:
@@ -32,6 +35,8 @@ class BaseIPField(models.Field):
     def get_prep_value(self, value):
     def get_prep_value(self, value):
         if not value:
         if not value:
             return None
             return None
+        if isinstance(value, list):
+            return [str(self.to_python(v)) for v in value]
         return str(self.to_python(value))
         return str(self.to_python(value))
 
 
     def form_class(self):
     def form_class(self):
@@ -90,5 +95,6 @@ IPAddressField.register_lookup(lookups.NetContainedOrEqual)
 IPAddressField.register_lookup(lookups.NetContains)
 IPAddressField.register_lookup(lookups.NetContains)
 IPAddressField.register_lookup(lookups.NetContainsOrEquals)
 IPAddressField.register_lookup(lookups.NetContainsOrEquals)
 IPAddressField.register_lookup(lookups.NetHost)
 IPAddressField.register_lookup(lookups.NetHost)
+IPAddressField.register_lookup(lookups.NetIn)
 IPAddressField.register_lookup(lookups.NetHostContained)
 IPAddressField.register_lookup(lookups.NetHostContained)
 IPAddressField.register_lookup(lookups.NetMaskLength)
 IPAddressField.register_lookup(lookups.NetMaskLength)

+ 3 - 8
netbox/ipam/filters.py

@@ -7,7 +7,7 @@ from netaddr.core import AddrFormatError
 from dcim.models import Device, Interface, Region, Site
 from dcim.models import Device, Interface, Region, Site
 from extras.filters import CustomFieldFilterSet, CreatedUpdatedFilterSet
 from extras.filters import CustomFieldFilterSet, CreatedUpdatedFilterSet
 from tenancy.filtersets import TenancyFilterSet
 from tenancy.filtersets import TenancyFilterSet
-from utilities.filters import NameSlugSearchFilterSet, NumericInFilter, TagFilter, TreeNodeMultipleChoiceFilter
+from utilities.filters import NameSlugSearchFilterSet, NumericInFilter, TagFilter, TreeNodeMultipleChoiceFilter, MultiValueCharFilter
 from virtualization.models import VirtualMachine
 from virtualization.models import VirtualMachine
 from .constants import *
 from .constants import *
 from .models import Aggregate, IPAddress, Prefix, RIR, Role, Service, VLAN, VLANGroup, VRF
 from .models import Aggregate, IPAddress, Prefix, RIR, Role, Service, VLAN, VLANGroup, VRF
@@ -284,7 +284,7 @@ class IPAddressFilter(TenancyFilterSet, CustomFieldFilterSet, CreatedUpdatedFilt
         method='search_by_parent',
         method='search_by_parent',
         label='Parent prefix',
         label='Parent prefix',
     )
     )
-    address = django_filters.CharFilter(
+    address = MultiValueCharFilter(
         method='filter_address',
         method='filter_address',
         label='Address',
         label='Address',
     )
     )
@@ -371,13 +371,8 @@ class IPAddressFilter(TenancyFilterSet, CustomFieldFilterSet, CreatedUpdatedFilt
             return queryset.none()
             return queryset.none()
 
 
     def filter_address(self, queryset, name, value):
     def filter_address(self, queryset, name, value):
-        if not value.strip():
-            return queryset
         try:
         try:
-            # Match address and subnet mask
-            if '/' in value:
-                return queryset.filter(address=value)
-            return queryset.filter(address__net_host=value)
+            return queryset.filter(address__net_in=value)
         except ValidationError:
         except ValidationError:
             return queryset.none()
             return queryset.none()
 
 

+ 36 - 0
netbox/ipam/lookups.py

@@ -100,6 +100,42 @@ class NetHost(Lookup):
         return 'HOST(%s) = %s' % (lhs, rhs), params
         return 'HOST(%s) = %s' % (lhs, rhs), params
 
 
 
 
+class NetIn(Lookup):
+    lookup_name = 'net_in'
+
+    def as_sql(self, qn, connection):
+        lhs, lhs_params = self.process_lhs(qn, connection)
+        rhs, rhs_params = self.process_rhs(qn, connection)
+        with_mask, without_mask = [], []
+        for address in rhs_params[0]:
+            if '/' in address:
+                with_mask.append(address)
+            else:
+                without_mask.append(address)
+
+        address_in_clause = self.create_in_clause('{} IN ('.format(lhs), len(with_mask))
+        host_in_clause = self.create_in_clause('HOST({}) IN ('.format(lhs), len(without_mask))
+
+        if with_mask and not without_mask:
+            return address_in_clause, with_mask
+        elif not with_mask and without_mask:
+            return host_in_clause, without_mask
+
+        in_clause = '({}) OR ({})'.format(address_in_clause, host_in_clause)
+        with_mask.extend(without_mask)
+        return in_clause, with_mask
+
+    @staticmethod
+    def create_in_clause(clause_part, max_size):
+        clause_elements = [clause_part]
+        for offset in range(0, max_size):
+            if offset > 0:
+                clause_elements.append(', ')
+            clause_elements.append('%s')
+        clause_elements.append(')')
+        return ''.join(clause_elements)
+
+
 class NetHostContained(Lookup):
 class NetHostContained(Lookup):
     """
     """
     Check for the host portion of an IP address without regard to its mask. This allows us to find e.g. 192.0.2.1/24
     Check for the host portion of an IP address without regard to its mask. This allows us to find e.g. 192.0.2.1/24

+ 17 - 11
netbox/ipam/tests/test_filters.py

@@ -337,16 +337,18 @@ class IPAddressTestCase(TestCase):
             IPAddress(family=4, address='10.0.0.2/24', vrf=vrfs[0], interface=interfaces[0], status=IPADDRESS_STATUS_ACTIVE, role=None, dns_name='ipaddress-b'),
             IPAddress(family=4, address='10.0.0.2/24', vrf=vrfs[0], interface=interfaces[0], status=IPADDRESS_STATUS_ACTIVE, role=None, dns_name='ipaddress-b'),
             IPAddress(family=4, address='10.0.0.3/24', vrf=vrfs[1], interface=interfaces[1], status=IPADDRESS_STATUS_RESERVED, role=IPADDRESS_ROLE_VIP, dns_name='ipaddress-c'),
             IPAddress(family=4, address='10.0.0.3/24', vrf=vrfs[1], interface=interfaces[1], status=IPADDRESS_STATUS_RESERVED, role=IPADDRESS_ROLE_VIP, dns_name='ipaddress-c'),
             IPAddress(family=4, address='10.0.0.4/24', vrf=vrfs[2], interface=interfaces[2], status=IPADDRESS_STATUS_DEPRECATED, role=IPADDRESS_ROLE_SECONDARY, dns_name='ipaddress-d'),
             IPAddress(family=4, address='10.0.0.4/24', vrf=vrfs[2], interface=interfaces[2], status=IPADDRESS_STATUS_DEPRECATED, role=IPADDRESS_ROLE_SECONDARY, dns_name='ipaddress-d'),
+            IPAddress(family=4, address='10.0.0.1/25', vrf=None, interface=None, status=IPADDRESS_STATUS_ACTIVE, role=None),
             IPAddress(family=6, address='2001:db8::1/64', vrf=None, interface=None, status=IPADDRESS_STATUS_ACTIVE, role=None, dns_name='ipaddress-a'),
             IPAddress(family=6, address='2001:db8::1/64', vrf=None, interface=None, status=IPADDRESS_STATUS_ACTIVE, role=None, dns_name='ipaddress-a'),
             IPAddress(family=6, address='2001:db8::2/64', vrf=vrfs[0], interface=interfaces[3], status=IPADDRESS_STATUS_ACTIVE, role=None, dns_name='ipaddress-b'),
             IPAddress(family=6, address='2001:db8::2/64', vrf=vrfs[0], interface=interfaces[3], status=IPADDRESS_STATUS_ACTIVE, role=None, dns_name='ipaddress-b'),
             IPAddress(family=6, address='2001:db8::3/64', vrf=vrfs[1], interface=interfaces[4], status=IPADDRESS_STATUS_RESERVED, role=IPADDRESS_ROLE_VIP, dns_name='ipaddress-c'),
             IPAddress(family=6, address='2001:db8::3/64', vrf=vrfs[1], interface=interfaces[4], status=IPADDRESS_STATUS_RESERVED, role=IPADDRESS_ROLE_VIP, dns_name='ipaddress-c'),
             IPAddress(family=6, address='2001:db8::4/64', vrf=vrfs[2], interface=interfaces[5], status=IPADDRESS_STATUS_DEPRECATED, role=IPADDRESS_ROLE_SECONDARY, dns_name='ipaddress-d'),
             IPAddress(family=6, address='2001:db8::4/64', vrf=vrfs[2], interface=interfaces[5], status=IPADDRESS_STATUS_DEPRECATED, role=IPADDRESS_ROLE_SECONDARY, dns_name='ipaddress-d'),
+            IPAddress(family=6, address='2001:db8::1/65', vrf=None, interface=None, status=IPADDRESS_STATUS_ACTIVE, role=None),
         )
         )
         IPAddress.objects.bulk_create(ipaddresses)
         IPAddress.objects.bulk_create(ipaddresses)
 
 
     def test_family(self):
     def test_family(self):
         params = {'family': '6'}
         params = {'family': '6'}
-        self.assertEqual(self.filterset(params, self.queryset).qs.count(), 4)
+        self.assertEqual(self.filterset(params, self.queryset).qs.count(), 5)
 
 
     def test_dns_name(self):
     def test_dns_name(self):
         params = {'dns_name': ['ipaddress-a', 'ipaddress-b']}
         params = {'dns_name': ['ipaddress-a', 'ipaddress-b']}
@@ -359,20 +361,24 @@ class IPAddressTestCase(TestCase):
 
 
     def test_parent(self):
     def test_parent(self):
         params = {'parent': '10.0.0.0/24'}
         params = {'parent': '10.0.0.0/24'}
-        self.assertEqual(self.filterset(params, self.queryset).qs.count(), 4)
+        self.assertEqual(self.filterset(params, self.queryset).qs.count(), 5)
         params = {'parent': '2001:db8::/64'}
         params = {'parent': '2001:db8::/64'}
-        self.assertEqual(self.filterset(params, self.queryset).qs.count(), 4)
+        self.assertEqual(self.filterset(params, self.queryset).qs.count(), 5)
 
 
-    def filter_address(self):
+    def test_filter_address(self):
         # Check IPv4 and IPv6, with and without a mask
         # Check IPv4 and IPv6, with and without a mask
-        params = {'address': '10.0.0.1/24'}
-        self.assertEqual(self.filterset(params, self.queryset).qs.count(), 1)
-        params = {'address': '10.0.0.1'}
+        params = {'address': ['10.0.0.1/24']}
         self.assertEqual(self.filterset(params, self.queryset).qs.count(), 1)
         self.assertEqual(self.filterset(params, self.queryset).qs.count(), 1)
-        params = {'address': '2001:db8::1/64'}
-        self.assertEqual(self.filterset(params, self.queryset).qs.count(), 1)
-        params = {'address': '2001:db8::1'}
+        params = {'address': ['10.0.0.1']}
+        self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
+        params = {'address': ['10.0.0.1/24', '10.0.0.1/25']}
+        self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
+        params = {'address': ['2001:db8::1/64']}
         self.assertEqual(self.filterset(params, self.queryset).qs.count(), 1)
         self.assertEqual(self.filterset(params, self.queryset).qs.count(), 1)
+        params = {'address': ['2001:db8::1']}
+        self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
+        params = {'address': ['2001:db8::1/64', '2001:db8::1/65']}
+        self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
 
 
     def test_mask_length(self):
     def test_mask_length(self):
         params = {'mask_length': '24'}
         params = {'mask_length': '24'}
@@ -411,7 +417,7 @@ class IPAddressTestCase(TestCase):
         params = {'assigned_to_interface': 'true'}
         params = {'assigned_to_interface': 'true'}
         self.assertEqual(self.filterset(params, self.queryset).qs.count(), 6)
         self.assertEqual(self.filterset(params, self.queryset).qs.count(), 6)
         params = {'assigned_to_interface': 'false'}
         params = {'assigned_to_interface': 'false'}
-        self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
+        self.assertEqual(self.filterset(params, self.queryset).qs.count(), 4)
 
 
     def test_status(self):
     def test_status(self):
         params = {'status': [PREFIX_STATUS_DEPRECATED, PREFIX_STATUS_RESERVED]}
         params = {'status': [PREFIX_STATUS_DEPRECATED, PREFIX_STATUS_RESERVED]}