Просмотр исходного кода

Merge pull request #3847 from kobayashi/3525

Fixes #3525: Filter muiltiple ipaddress terms
Jeremy Stretch 6 лет назад
Родитель
Сommit
d96f474a5f

+ 4 - 0
docs/release-notes/version-2.6.md

@@ -1,5 +1,9 @@
 # v2.6.13 (FUTURE)
 
+## Enhancements
+
+* [#3525](https://github.com/netbox-community/netbox/issues/3525) - Enable IP address filtering with multiple address terms
+
 ## Bug Fixes
 
 * [#3914](https://github.com/netbox-community/netbox/issues/3914) - Fix interface filter field when unauthenticated

+ 8 - 2
netbox/ipam/fields.py

@@ -1,6 +1,6 @@
 from django.core.exceptions import ValidationError
 from django.db import models
-from netaddr import AddrFormatError, IPNetwork
+from netaddr import AddrFormatError, IPNetwork, IPAddress
 
 from . import lookups
 from .formfields import IPFormField
@@ -23,7 +23,10 @@ class BaseIPField(models.Field):
         if not value:
             return value
         try:
-            return IPNetwork(value)
+            if '/' in str(value):
+                return IPNetwork(value)
+            else:
+                return IPAddress(value)
         except AddrFormatError as e:
             raise ValidationError("Invalid IP address format: {}".format(value))
         except (TypeError, ValueError) as e:
@@ -32,6 +35,8 @@ class BaseIPField(models.Field):
     def get_prep_value(self, value):
         if not value:
             return None
+        if isinstance(value, list):
+            return [str(self.to_python(v)) for v in value]
         return str(self.to_python(value))
 
     def form_class(self):
@@ -90,5 +95,6 @@ IPAddressField.register_lookup(lookups.NetContainedOrEqual)
 IPAddressField.register_lookup(lookups.NetContains)
 IPAddressField.register_lookup(lookups.NetContainsOrEquals)
 IPAddressField.register_lookup(lookups.NetHost)
+IPAddressField.register_lookup(lookups.NetIn)
 IPAddressField.register_lookup(lookups.NetHostContained)
 IPAddressField.register_lookup(lookups.NetMaskLength)

+ 3 - 8
netbox/ipam/filters.py

@@ -7,7 +7,7 @@ from netaddr.core import AddrFormatError
 from dcim.models import Device, Interface, Region, Site
 from extras.filters import CustomFieldFilterSet, CreatedUpdatedFilterSet
 from tenancy.filtersets import TenancyFilterSet
-from utilities.filters import NameSlugSearchFilterSet, NumericInFilter, TagFilter, TreeNodeMultipleChoiceFilter
+from utilities.filters import NameSlugSearchFilterSet, NumericInFilter, TagFilter, TreeNodeMultipleChoiceFilter, MultiValueCharFilter
 from virtualization.models import VirtualMachine
 from .constants import *
 from .models import Aggregate, IPAddress, Prefix, RIR, Role, Service, VLAN, VLANGroup, VRF
@@ -284,7 +284,7 @@ class IPAddressFilter(TenancyFilterSet, CustomFieldFilterSet, CreatedUpdatedFilt
         method='search_by_parent',
         label='Parent prefix',
     )
-    address = django_filters.CharFilter(
+    address = MultiValueCharFilter(
         method='filter_address',
         label='Address',
     )
@@ -371,13 +371,8 @@ class IPAddressFilter(TenancyFilterSet, CustomFieldFilterSet, CreatedUpdatedFilt
             return queryset.none()
 
     def filter_address(self, queryset, name, value):
-        if not value.strip():
-            return queryset
         try:
-            # Match address and subnet mask
-            if '/' in value:
-                return queryset.filter(address=value)
-            return queryset.filter(address__net_host=value)
+            return queryset.filter(address__net_in=value)
         except ValidationError:
             return queryset.none()
 

+ 36 - 0
netbox/ipam/lookups.py

@@ -100,6 +100,42 @@ class NetHost(Lookup):
         return 'HOST(%s) = %s' % (lhs, rhs), params
 
 
+class NetIn(Lookup):
+    lookup_name = 'net_in'
+
+    def as_sql(self, qn, connection):
+        lhs, lhs_params = self.process_lhs(qn, connection)
+        rhs, rhs_params = self.process_rhs(qn, connection)
+        with_mask, without_mask = [], []
+        for address in rhs_params[0]:
+            if '/' in address:
+                with_mask.append(address)
+            else:
+                without_mask.append(address)
+
+        address_in_clause = self.create_in_clause('{} IN ('.format(lhs), len(with_mask))
+        host_in_clause = self.create_in_clause('HOST({}) IN ('.format(lhs), len(without_mask))
+
+        if with_mask and not without_mask:
+            return address_in_clause, with_mask
+        elif not with_mask and without_mask:
+            return host_in_clause, without_mask
+
+        in_clause = '({}) OR ({})'.format(address_in_clause, host_in_clause)
+        with_mask.extend(without_mask)
+        return in_clause, with_mask
+
+    @staticmethod
+    def create_in_clause(clause_part, max_size):
+        clause_elements = [clause_part]
+        for offset in range(0, max_size):
+            if offset > 0:
+                clause_elements.append(', ')
+            clause_elements.append('%s')
+        clause_elements.append(')')
+        return ''.join(clause_elements)
+
+
 class NetHostContained(Lookup):
     """
     Check for the host portion of an IP address without regard to its mask. This allows us to find e.g. 192.0.2.1/24

+ 17 - 11
netbox/ipam/tests/test_filters.py

@@ -337,16 +337,18 @@ class IPAddressTestCase(TestCase):
             IPAddress(family=4, address='10.0.0.2/24', vrf=vrfs[0], interface=interfaces[0], status=IPADDRESS_STATUS_ACTIVE, role=None, dns_name='ipaddress-b'),
             IPAddress(family=4, address='10.0.0.3/24', vrf=vrfs[1], interface=interfaces[1], status=IPADDRESS_STATUS_RESERVED, role=IPADDRESS_ROLE_VIP, dns_name='ipaddress-c'),
             IPAddress(family=4, address='10.0.0.4/24', vrf=vrfs[2], interface=interfaces[2], status=IPADDRESS_STATUS_DEPRECATED, role=IPADDRESS_ROLE_SECONDARY, dns_name='ipaddress-d'),
+            IPAddress(family=4, address='10.0.0.1/25', vrf=None, interface=None, status=IPADDRESS_STATUS_ACTIVE, role=None),
             IPAddress(family=6, address='2001:db8::1/64', vrf=None, interface=None, status=IPADDRESS_STATUS_ACTIVE, role=None, dns_name='ipaddress-a'),
             IPAddress(family=6, address='2001:db8::2/64', vrf=vrfs[0], interface=interfaces[3], status=IPADDRESS_STATUS_ACTIVE, role=None, dns_name='ipaddress-b'),
             IPAddress(family=6, address='2001:db8::3/64', vrf=vrfs[1], interface=interfaces[4], status=IPADDRESS_STATUS_RESERVED, role=IPADDRESS_ROLE_VIP, dns_name='ipaddress-c'),
             IPAddress(family=6, address='2001:db8::4/64', vrf=vrfs[2], interface=interfaces[5], status=IPADDRESS_STATUS_DEPRECATED, role=IPADDRESS_ROLE_SECONDARY, dns_name='ipaddress-d'),
+            IPAddress(family=6, address='2001:db8::1/65', vrf=None, interface=None, status=IPADDRESS_STATUS_ACTIVE, role=None),
         )
         IPAddress.objects.bulk_create(ipaddresses)
 
     def test_family(self):
         params = {'family': '6'}
-        self.assertEqual(self.filterset(params, self.queryset).qs.count(), 4)
+        self.assertEqual(self.filterset(params, self.queryset).qs.count(), 5)
 
     def test_dns_name(self):
         params = {'dns_name': ['ipaddress-a', 'ipaddress-b']}
@@ -359,20 +361,24 @@ class IPAddressTestCase(TestCase):
 
     def test_parent(self):
         params = {'parent': '10.0.0.0/24'}
-        self.assertEqual(self.filterset(params, self.queryset).qs.count(), 4)
+        self.assertEqual(self.filterset(params, self.queryset).qs.count(), 5)
         params = {'parent': '2001:db8::/64'}
-        self.assertEqual(self.filterset(params, self.queryset).qs.count(), 4)
+        self.assertEqual(self.filterset(params, self.queryset).qs.count(), 5)
 
-    def filter_address(self):
+    def test_filter_address(self):
         # Check IPv4 and IPv6, with and without a mask
-        params = {'address': '10.0.0.1/24'}
-        self.assertEqual(self.filterset(params, self.queryset).qs.count(), 1)
-        params = {'address': '10.0.0.1'}
+        params = {'address': ['10.0.0.1/24']}
         self.assertEqual(self.filterset(params, self.queryset).qs.count(), 1)
-        params = {'address': '2001:db8::1/64'}
-        self.assertEqual(self.filterset(params, self.queryset).qs.count(), 1)
-        params = {'address': '2001:db8::1'}
+        params = {'address': ['10.0.0.1']}
+        self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
+        params = {'address': ['10.0.0.1/24', '10.0.0.1/25']}
+        self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
+        params = {'address': ['2001:db8::1/64']}
         self.assertEqual(self.filterset(params, self.queryset).qs.count(), 1)
+        params = {'address': ['2001:db8::1']}
+        self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
+        params = {'address': ['2001:db8::1/64', '2001:db8::1/65']}
+        self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
 
     def test_mask_length(self):
         params = {'mask_length': '24'}
@@ -411,7 +417,7 @@ class IPAddressTestCase(TestCase):
         params = {'assigned_to_interface': 'true'}
         self.assertEqual(self.filterset(params, self.queryset).qs.count(), 6)
         params = {'assigned_to_interface': 'false'}
-        self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
+        self.assertEqual(self.filterset(params, self.queryset).qs.count(), 4)
 
     def test_status(self):
         params = {'status': [PREFIX_STATUS_DEPRECATED, PREFIX_STATUS_RESERVED]}