Jelajahi Sumber

Fixes #20541: Enhance filter methods with dynamic prefixing (#20579)

Martin Hauser 3 bulan lalu
induk
melakukan
639bc4462b
2 mengubah file dengan 250 tambahan dan 18 penghapusan
  1. 49 18
      netbox/ipam/graphql/filters.py
  2. 201 0
      netbox/ipam/tests/test_api.py

+ 49 - 18
netbox/ipam/graphql/filters.py

@@ -79,12 +79,28 @@ class ASNRangeFilter(TenancyFilterMixin, OrganizationalModelFilterMixin):
 
 @strawberry_django.filter_type(models.Aggregate, lookups=True)
 class AggregateFilter(ContactFilterMixin, TenancyFilterMixin, PrimaryModelFilterMixin):
-    prefix: Annotated['PrefixFilter', strawberry.lazy('ipam.graphql.filters')] | None = strawberry_django.filter_field()
-    prefix_id: ID | None = strawberry_django.filter_field()
+    prefix: FilterLookup[str] | None = strawberry_django.filter_field()
     rir: Annotated['RIRFilter', strawberry.lazy('ipam.graphql.filters')] | None = strawberry_django.filter_field()
     rir_id: ID | None = strawberry_django.filter_field()
     date_added: DateFilterLookup[date] | None = strawberry_django.filter_field()
 
+    @strawberry_django.filter_field()
+    def contains(self, value: list[str], prefix) -> Q:
+        """
+        Return aggregates whose `prefix` contains any of the supplied networks.
+        Mirrors PrefixFilter.contains but operates on the Aggregate.prefix field itself.
+        """
+        if not value:
+            return Q()
+        q = Q()
+        for subnet in value:
+            try:
+                query = str(netaddr.IPNetwork(subnet.strip()).cidr)
+            except (AddrFormatError, ValueError):
+                continue
+            q |= Q(**{f"{prefix}prefix__net_contains": query})
+        return q
+
 
 @strawberry_django.filter_type(models.FHRPGroup, lookups=True)
 class FHRPGroupFilter(PrimaryModelFilterMixin):
@@ -119,28 +135,28 @@ class FHRPGroupAssignmentFilter(BaseObjectTypeFilterMixin, ChangeLogFilterMixin)
     )
 
     @strawberry_django.filter_field()
-    def device_id(self, queryset, value: list[str], prefix) -> Q:
-        return self.filter_device('id', value)
+    def device_id(self, value: list[str], prefix) -> Q:
+        return self.filter_device('id', value, prefix)
 
     @strawberry_django.filter_field()
     def device(self, value: list[str], prefix) -> Q:
-        return self.filter_device('name', value)
+        return self.filter_device('name', value, prefix)
 
     @strawberry_django.filter_field()
     def virtual_machine_id(self, value: list[str], prefix) -> Q:
-        return Q(interface_id__in=VMInterface.objects.filter(virtual_machine_id__in=value))
+        return Q(**{f"{prefix}interface_id__in": VMInterface.objects.filter(virtual_machine_id__in=value)})
 
     @strawberry_django.filter_field()
     def virtual_machine(self, value: list[str], prefix) -> Q:
-        return Q(interface_id__in=VMInterface.objects.filter(virtual_machine__name__in=value))
+        return Q(**{f"{prefix}interface_id__in": VMInterface.objects.filter(virtual_machine__name__in=value)})
 
-    def filter_device(self, field, value) -> Q:
+    def filter_device(self, field, value, prefix) -> Q:
         """Helper to standardize logic for device and device_id filters"""
         devices = Device.objects.filter(**{f'{field}__in': value})
         interface_ids = []
         for device in devices:
             interface_ids.extend(device.vc_interfaces().values_list('id', flat=True))
-        return Q(interface_id__in=interface_ids)
+        return Q(**{f"{prefix}interface_id__in": interface_ids})
 
 
 @strawberry_django.filter_type(models.IPAddress, lookups=True)
@@ -180,9 +196,9 @@ class IPAddressFilter(ContactFilterMixin, TenancyFilterMixin, PrimaryModelFilter
         for subnet in value:
             try:
                 query = str(netaddr.IPNetwork(subnet.strip()).cidr)
-                q |= Q(address__net_host_contained=query)
             except (AddrFormatError, ValueError):
-                return Q()
+                continue
+            q |= Q(**{f"{prefix}address__net_host_contained": query})
         return q
 
     @strawberry_django.filter_field()
@@ -217,9 +233,14 @@ class IPRangeFilter(ContactFilterMixin, TenancyFilterMixin, PrimaryModelFilterMi
         for subnet in value:
             try:
                 query = str(netaddr.IPNetwork(subnet.strip()).cidr)
-                q |= Q(start_address__net_host_contained=query, end_address__net_host_contained=query)
             except (AddrFormatError, ValueError):
-                return Q()
+                continue
+            q |= Q(
+                **{
+                    f"{prefix}start_address__net_host_contained": query,
+                    f"{prefix}end_address__net_host_contained": query,
+                }
+            )
         return q
 
     @strawberry_django.filter_field()
@@ -228,10 +249,17 @@ class IPRangeFilter(ContactFilterMixin, TenancyFilterMixin, PrimaryModelFilterMi
             return Q()
         q = Q()
         for subnet in value:
-            net = netaddr.IPNetwork(subnet.strip())
+            try:
+                net = netaddr.IPNetwork(subnet.strip())
+                query_start = str(netaddr.IPAddress(net.first))
+                query_end = str(netaddr.IPAddress(net.last))
+            except (AddrFormatError, ValueError):
+                continue
             q |= Q(
-                start_address__host__inet__lte=str(netaddr.IPAddress(net.first)),
-                end_address__host__inet__gte=str(netaddr.IPAddress(net.last)),
+                **{
+                    f"{prefix}start_address__host__inet__lte": query_start,
+                    f"{prefix}end_address__host__inet__gte": query_end,
+                }
             )
         return q
 
@@ -257,8 +285,11 @@ class PrefixFilter(ContactFilterMixin, ScopedFilterMixin, TenancyFilterMixin, Pr
             return Q()
         q = Q()
         for subnet in value:
-            query = str(netaddr.IPNetwork(subnet.strip()).cidr)
-            q |= Q(prefix__net_contains=query)
+            try:
+                query = str(netaddr.IPNetwork(subnet.strip()).cidr)
+            except (AddrFormatError, ValueError):
+                continue
+            q |= Q(**{f"{prefix}prefix__net_contains": query})
         return q
 
 

+ 201 - 0
netbox/ipam/tests/test_api.py

@@ -323,6 +323,55 @@ class AggregateTest(APIViewTestCases.APIViewTestCase):
             },
         ]
 
+    @tag('regression')
+    def test_graphql_aggregate_prefix_exact(self):
+        """
+        Test case to verify aggregate prefix equality via field lookup in GraphQL API.
+        """
+
+        self.add_permissions('ipam.view_aggregate', 'ipam.view_rir')
+
+        rir = RIR.objects.create(name='RFC6598', slug='rfc6598', is_private=True)
+        aggregate1 = Aggregate.objects.create(prefix='100.64.0.0/10', rir=rir)
+        Aggregate.objects.create(prefix='203.0.113.0/24', rir=rir)
+
+        url = reverse('graphql')
+        query = """{
+            aggregate_list(filters: { prefix: { exact: "100.64.0.0/10" } }) { prefix }
+        }"""
+        response = self.client.post(url, data={'query': query}, format='json', **self.header)
+        self.assertHttpStatus(response, status.HTTP_200_OK)
+        data = response.json()
+        self.assertNotIn('errors', data)
+
+        prefixes = {row['prefix'] for row in data['data']['aggregate_list']}
+        self.assertIn(str(aggregate1.prefix), prefixes)
+
+    @tag('regression')
+    def test_graphql_aggregate_contains_skips_invalid(self):
+        """
+        Test the GraphQL API Aggregate `contains` filter skips invalid input.
+        """
+
+        self.add_permissions('ipam.view_aggregate', 'ipam.view_rir')
+
+        rir = RIR.objects.create(name='RIR 3', slug='rir-3', is_private=False)
+        aggregate1 = Aggregate.objects.create(prefix='100.64.0.0/10', rir=rir)
+        Aggregate.objects.create(prefix='203.0.113.0/24', rir=rir)
+
+        url = reverse('graphql')
+        query = """{
+            aggregate_list(filters: { contains: ["100.64.16.0/24", "not-a-cidr", ""] }) { prefix }
+        }"""
+        response = self.client.post(url, data={'query': query}, format='json', **self.header)
+        self.assertHttpStatus(response, status.HTTP_200_OK)
+        data = response.json()
+        self.assertNotIn('errors', data)
+
+        prefixes = {row['prefix'] for row in data['data']['aggregate_list']}
+        self.assertIn(str(aggregate1.prefix), prefixes)
+        # No exception occurred; invalid entries were ignored
+
 
 class RoleTest(APIViewTestCases.APIViewTestCase):
     model = Role
@@ -546,6 +595,30 @@ class PrefixTest(APIViewTestCases.APIViewTestCase):
         self.assertHttpStatus(response, status.HTTP_201_CREATED)
         self.assertEqual(len(response.data), 8)
 
+    @tag('regression')
+    def test_graphql_tenant_prefixes_contains_nested_skips_invalid(self):
+        """
+        Test the GraphQL API Tenant nested Prefix `contains` filter skips invalid input.
+        """
+
+        self.add_permissions('ipam.view_prefix', 'ipam.view_vrf', 'tenancy.view_tenant')
+
+        tenant = Tenant.objects.create(name='Tenant 1', slug='tenant-1')
+        vrf = VRF.objects.create(name='Test VRF 1', rd='64512:1')
+        Prefix.objects.create(prefix='10.20.0.0/16', vrf=vrf, tenant=tenant)
+        Prefix.objects.create(prefix='198.51.100.0/24', vrf=vrf)  # non-tenant
+
+        url = reverse('graphql')
+        query = """{
+            tenant_list(filters: { prefixes: { contains: ["10.20.1.0/24", "not-a-cidr"] } }) { id }
+        }"""
+        response = self.client.post(url, data={'query': query}, format='json', **self.header)
+        self.assertHttpStatus(response, status.HTTP_200_OK)
+        data = response.json()
+        self.assertNotIn('errors', data)
+
+        self.assertTrue(data['data']['tenant_list'])  # tenant returned
+
 
 class IPRangeTest(APIViewTestCases.APIViewTestCase):
     model = IPRange
@@ -645,6 +718,65 @@ class IPRangeTest(APIViewTestCases.APIViewTestCase):
         self.assertHttpStatus(response, status.HTTP_201_CREATED)
         self.assertEqual(len(response.data), 8)
 
+    @tag('regression')
+    def test_graphql_tenant_ip_ranges_parent_nested_skips_invalid(self):
+        """
+        Test the GraphQL API Tenant nested IP Range `parent` filter skips invalid input.
+        """
+
+        self.add_permissions('tenancy.view_tenant', 'ipam.view_iprange', 'ipam.view_vrf')
+
+        tenant = Tenant.objects.create(name='Tenant 1', slug='tenant-1')
+        vrf = VRF.objects.create(name='Test VRF 1', rd='64512:1')
+        IPRange.objects.create(
+            start_address=IPNetwork('10.30.0.1/24'), end_address=IPNetwork('10.30.0.255/24'), vrf=vrf, tenant=tenant
+        )
+        IPRange.objects.create(
+            start_address=IPNetwork('10.31.0.1/24'), end_address=IPNetwork('10.31.0.255/24'), vrf=vrf, tenant=tenant
+        )
+
+        url = reverse('graphql')
+        query = """{
+            tenant_list(filters: {
+                name: { exact: "Tenant 1" }
+                ip_ranges: { parent: ["10.30.0.0/24", "bogus"] }
+            }) { id }
+        }"""
+        response = self.client.post(url, data={'query': query}, format='json', **self.header)
+        self.assertHttpStatus(response, status.HTTP_200_OK)
+        data = response.json()
+        self.assertNotIn('errors', data)
+        self.assertTrue(data['data']['tenant_list'])  # tenant returned
+        # No exception occurred; invalid entries were ignored
+
+    @tag('regression')
+    def test_graphql_tenant_ip_ranges_contains_nested_skips_invalid(self):
+        """
+        Test the GraphQL API Tenant nested IP Range `contains` filter skips invalid input.
+        """
+
+        self.add_permissions('tenancy.view_tenant', 'ipam.view_iprange', 'ipam.view_vrf')
+
+        tenant = Tenant.objects.create(name='Tenant 2', slug='tenant-2')
+        vrf = VRF.objects.create(name='Test VRF 1', rd='64512:2')
+        IPRange.objects.create(
+            start_address=IPNetwork('10.40.0.1/24'), end_address=IPNetwork('10.40.0.255/24'), vrf=vrf, tenant=tenant
+        )
+
+        url = reverse('graphql')
+        query = """{
+            tenant_list(filters: {
+                name: { exact: "Tenant 2" }
+                ip_ranges: { contains: ["10.40.0.128/25", "###"] }
+            }) { id }
+        }"""
+        response = self.client.post(url, data={'query': query}, format='json', **self.header)
+        self.assertHttpStatus(response, status.HTTP_200_OK)
+        data = response.json()
+        self.assertNotIn('errors', data)
+        self.assertTrue(data['data']['tenant_list'])  # tenant returned
+        # No exception occurred; invalid entries were ignored
+
 
 class IPAddressTest(APIViewTestCases.APIViewTestCase):
     model = IPAddress
@@ -731,6 +863,75 @@ class IPAddressTest(APIViewTestCases.APIViewTestCase):
         response = self.client.patch(url, data, format='json', **self.header)
         self.assertHttpStatus(response, status.HTTP_400_BAD_REQUEST)
 
+    @tag('regression')
+    def test_graphql_device_primary_ip4_assigned_nested(self):
+        """
+        Test the GraphQL API Device nested IP Address `primary_ip4` filter.
+        """
+
+        self.add_permissions('dcim.view_device', 'dcim.view_interface', 'ipam.view_ipaddress')
+
+        site = Site.objects.create(name='Site 1')
+        manufacturer = Manufacturer.objects.create(name='Manufacturer 1')
+        device_type = DeviceType.objects.create(model='Device Type 1', manufacturer=manufacturer)
+        role = DeviceRole.objects.create(name='Switch')
+
+        device1 = Device.objects.create(name='Device 1', site=site, device_type=device_type, role=role, status='active')
+        interface1 = Interface.objects.create(name='Interface 1', device=device1, type='1000baset')
+        ip1 = IPAddress.objects.create(address='10.0.0.1/24')
+        ip1.assigned_object = interface1
+        ip1.save()
+        device1.primary_ip4 = ip1
+        device1.save()
+
+        device2 = Device.objects.create(name='Device 2', site=site, device_type=device_type, role=role, status='active')
+
+        url = reverse('graphql')
+        query = """{
+            device_list(filters: { primary_ip4: { assigned: true } }) { id name }
+        }"""
+        response = self.client.post(url, data={'query': query}, format='json', **self.header)
+        self.assertHttpStatus(response, status.HTTP_200_OK)
+        data = response.json()
+        self.assertNotIn('errors', data)
+
+        ids = {row['id'] for row in data['data']['device_list']}
+        self.assertIn(str(device1.pk), ids)
+        self.assertNotIn(str(device2.pk), ids)
+
+    @tag('regression')
+    def test_graphql_device_primary_ip4_parent_nested_skips_invalid(self):
+        """
+        Test the GraphQL API Device nested IP Address `parent` filter skips invalid input.
+        """
+
+        self.add_permissions('dcim.view_device', 'dcim.view_interface', 'ipam.view_ipaddress')
+
+        site = Site.objects.create(name='Site 1')
+        manufacturer = Manufacturer.objects.create(name='Manufacturer 1')
+        device_type = DeviceType.objects.create(model='Device Type 1', manufacturer=manufacturer)
+        role = DeviceRole.objects.create(name='Switch')
+
+        device1 = Device.objects.create(name='Device 1', site=site, device_type=device_type, role=role, status='active')
+        interface1 = Interface.objects.create(name='Interface 1', device=device1, type='1000baset')
+        ip1 = IPAddress.objects.create(address='192.0.2.10/24')
+        ip1.assigned_object = interface1
+        ip1.save()
+        device1.primary_ip4 = ip1
+        device1.save()
+
+        url = reverse('graphql')
+        query = """{
+            device_list(filters: { primary_ip4: { parent: ["192.0.2.0/24", "bad-cidr"] } }) { id }
+        }"""
+        response = self.client.post(url, data={'query': query}, format='json', **self.header)
+        self.assertHttpStatus(response, status.HTTP_200_OK)
+        data = response.json()
+        self.assertNotIn('errors', data)
+
+        ids = {row['id'] for row in data['data']['device_list']}
+        self.assertIn(str(device1.pk), ids)
+
 
 class FHRPGroupTest(APIViewTestCases.APIViewTestCase):
     model = FHRPGroup