Просмотр исходного кода

Filter muiltiple ipaddress terms

kobayashi 6 лет назад
Родитель
Сommit
2e9f21e222

+ 1 - 0
docs/release-notes/version-2.6.md

@@ -11,6 +11,7 @@
 * [#3187](https://github.com/netbox-community/netbox/issues/3187) - Add rack selection field to rack elevations
 * [#3393](https://github.com/netbox-community/netbox/issues/3393) - Paginate the circuits at the provider details view
 * [#3440](https://github.com/netbox-community/netbox/issues/3440) - Add total length to cable trace
+* [#3525](https://github.com/netbox-community/netbox/issues/3525) - Enable ipaddress filtering with multiple address terms
 * [#3623](https://github.com/netbox-community/netbox/issues/3623) - Add word expansion during interface creation
 * [#3668](https://github.com/netbox-community/netbox/issues/3668) - Search by DNS name when assigning IP address
 * [#3851](https://github.com/netbox-community/netbox/issues/3851) - Allow passing initial data to custom script forms

+ 8 - 2
netbox/ipam/fields.py

@@ -1,6 +1,6 @@
 from django.core.exceptions import ValidationError
 from django.db import models
-from netaddr import AddrFormatError, IPNetwork
+from netaddr import AddrFormatError, IPNetwork, IPAddress
 
 from . import lookups
 from .formfields import IPFormField
@@ -23,7 +23,10 @@ class BaseIPField(models.Field):
         if not value:
             return value
         try:
-            return IPNetwork(value)
+            if '/' in str(value):
+                return IPNetwork(value)
+            else:
+                return IPAddress(value)
         except AddrFormatError as e:
             raise ValidationError("Invalid IP address format: {}".format(value))
         except (TypeError, ValueError) as e:
@@ -32,6 +35,8 @@ class BaseIPField(models.Field):
     def get_prep_value(self, value):
         if not value:
             return None
+        if isinstance(value, list):
+            return [str(self.to_python(v)) for v in value]
         return str(self.to_python(value))
 
     def form_class(self):
@@ -90,5 +95,6 @@ IPAddressField.register_lookup(lookups.NetContainedOrEqual)
 IPAddressField.register_lookup(lookups.NetContains)
 IPAddressField.register_lookup(lookups.NetContainsOrEquals)
 IPAddressField.register_lookup(lookups.NetHost)
+IPAddressField.register_lookup(lookups.NetHostIn)
 IPAddressField.register_lookup(lookups.NetHostContained)
 IPAddressField.register_lookup(lookups.NetMaskLength)

+ 6 - 8
netbox/ipam/filters.py

@@ -7,7 +7,7 @@ from netaddr.core import AddrFormatError
 from dcim.models import Device, Interface, Region, Site
 from extras.filters import CustomFieldFilterSet, CreatedUpdatedFilterSet
 from tenancy.filtersets import TenancyFilterSet
-from utilities.filters import NameSlugSearchFilterSet, NumericInFilter, TagFilter, TreeNodeMultipleChoiceFilter
+from utilities.filters import NameSlugSearchFilterSet, NumericInFilter, TagFilter, TreeNodeMultipleChoiceFilter, MultiValueCharFilter
 from virtualization.models import VirtualMachine
 from .constants import *
 from .models import Aggregate, IPAddress, Prefix, RIR, Role, Service, VLAN, VLANGroup, VRF
@@ -284,7 +284,7 @@ class IPAddressFilter(TenancyFilterSet, CustomFieldFilterSet, CreatedUpdatedFilt
         method='search_by_parent',
         label='Parent prefix',
     )
-    address = django_filters.CharFilter(
+    address = MultiValueCharFilter(
         method='filter_address',
         label='Address',
     )
@@ -371,13 +371,11 @@ class IPAddressFilter(TenancyFilterSet, CustomFieldFilterSet, CreatedUpdatedFilt
             return queryset.none()
 
     def filter_address(self, queryset, name, value):
-        if not value.strip():
-            return queryset
         try:
-            # Match address and subnet mask
-            if '/' in value:
-                return queryset.filter(address=value)
-            return queryset.filter(address__net_host=value)
+            return queryset.filter(
+                Q(address__in=value) |
+                Q(address__net_host_in=value)
+            )
         except ValidationError:
             return queryset.none()
 

+ 19 - 0
netbox/ipam/lookups.py

@@ -100,6 +100,25 @@ class NetHost(Lookup):
         return 'HOST(%s) = %s' % (lhs, rhs), params
 
 
+class NetHostIn(Lookup):
+    lookup_name = 'net_host_in'
+
+    def as_sql(self, qn, connection):
+        lhs, lhs_params = self.process_lhs(qn, connection)
+        rhs, rhs_params = self.process_rhs(qn, connection)
+        in_elements = ['HOST(%s) IN (' % lhs]
+        params = []
+        for offset in range(0, len(rhs_params[0])):
+            if offset > 0:
+                in_elements.append(', ')
+            params.extend(lhs_params)
+            sqls_params = rhs_params[0][offset]
+            in_elements.append(rhs)
+            params.append(sqls_params)
+        in_elements.append(')')
+        return ''.join(in_elements), params
+
+
 class NetHostContained(Lookup):
     """
     Check for the host portion of an IP address without regard to its mask. This allows us to find e.g. 192.0.2.1/24

+ 17 - 11
netbox/ipam/tests/test_filters.py

@@ -337,16 +337,18 @@ class IPAddressTestCase(TestCase):
             IPAddress(family=4, address='10.0.0.2/24', vrf=vrfs[0], interface=interfaces[0], status=IPADDRESS_STATUS_ACTIVE, role=None, dns_name='ipaddress-b'),
             IPAddress(family=4, address='10.0.0.3/24', vrf=vrfs[1], interface=interfaces[1], status=IPADDRESS_STATUS_RESERVED, role=IPADDRESS_ROLE_VIP, dns_name='ipaddress-c'),
             IPAddress(family=4, address='10.0.0.4/24', vrf=vrfs[2], interface=interfaces[2], status=IPADDRESS_STATUS_DEPRECATED, role=IPADDRESS_ROLE_SECONDARY, dns_name='ipaddress-d'),
+            IPAddress(family=4, address='10.0.0.1/25', vrf=None, interface=None, status=IPADDRESS_STATUS_ACTIVE, role=None),
             IPAddress(family=6, address='2001:db8::1/64', vrf=None, interface=None, status=IPADDRESS_STATUS_ACTIVE, role=None, dns_name='ipaddress-a'),
             IPAddress(family=6, address='2001:db8::2/64', vrf=vrfs[0], interface=interfaces[3], status=IPADDRESS_STATUS_ACTIVE, role=None, dns_name='ipaddress-b'),
             IPAddress(family=6, address='2001:db8::3/64', vrf=vrfs[1], interface=interfaces[4], status=IPADDRESS_STATUS_RESERVED, role=IPADDRESS_ROLE_VIP, dns_name='ipaddress-c'),
             IPAddress(family=6, address='2001:db8::4/64', vrf=vrfs[2], interface=interfaces[5], status=IPADDRESS_STATUS_DEPRECATED, role=IPADDRESS_ROLE_SECONDARY, dns_name='ipaddress-d'),
+            IPAddress(family=6, address='2001:db8::1/65', vrf=None, interface=None, status=IPADDRESS_STATUS_ACTIVE, role=None),
         )
         IPAddress.objects.bulk_create(ipaddresses)
 
     def test_family(self):
         params = {'family': '6'}
-        self.assertEqual(self.filterset(params, self.queryset).qs.count(), 4)
+        self.assertEqual(self.filterset(params, self.queryset).qs.count(), 5)
 
     def test_dns_name(self):
         params = {'dns_name': ['ipaddress-a', 'ipaddress-b']}
@@ -359,20 +361,24 @@ class IPAddressTestCase(TestCase):
 
     def test_parent(self):
         params = {'parent': '10.0.0.0/24'}
-        self.assertEqual(self.filterset(params, self.queryset).qs.count(), 4)
+        self.assertEqual(self.filterset(params, self.queryset).qs.count(), 5)
         params = {'parent': '2001:db8::/64'}
-        self.assertEqual(self.filterset(params, self.queryset).qs.count(), 4)
+        self.assertEqual(self.filterset(params, self.queryset).qs.count(), 5)
 
-    def filter_address(self):
+    def test_filter_address(self):
         # Check IPv4 and IPv6, with and without a mask
-        params = {'address': '10.0.0.1/24'}
-        self.assertEqual(self.filterset(params, self.queryset).qs.count(), 1)
-        params = {'address': '10.0.0.1'}
+        params = {'address': ['10.0.0.1/24']}
         self.assertEqual(self.filterset(params, self.queryset).qs.count(), 1)
-        params = {'address': '2001:db8::1/64'}
-        self.assertEqual(self.filterset(params, self.queryset).qs.count(), 1)
-        params = {'address': '2001:db8::1'}
+        params = {'address': ['10.0.0.1']}
+        self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
+        params = {'address': ['10.0.0.1/24', '10.0.0.1/25']}
+        self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
+        params = {'address': ['2001:db8::1/64']}
         self.assertEqual(self.filterset(params, self.queryset).qs.count(), 1)
+        params = {'address': ['2001:db8::1']}
+        self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
+        params = {'address': ['2001:db8::1/64', '2001:db8::1/65']}
+        self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
 
     def test_mask_length(self):
         params = {'mask_length': '24'}
@@ -411,7 +417,7 @@ class IPAddressTestCase(TestCase):
         params = {'assigned_to_interface': 'true'}
         self.assertEqual(self.filterset(params, self.queryset).qs.count(), 6)
         params = {'assigned_to_interface': 'false'}
-        self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
+        self.assertEqual(self.filterset(params, self.queryset).qs.count(), 4)
 
     def test_status(self):
         params = {'status': [PREFIX_STATUS_DEPRECATED, PREFIX_STATUS_RESERVED]}