Przeglądaj źródła

feat(extras): Add range_contains ORM lookup

Introduce a generic lookup for ArrayField(RangeField) that matches rows
where a scalar value is contained by any range in the array
(e.g. VLANGroup.vid_ranges).
Replace the raw-SQL helper in the VLANGroup FilterSet (`contains_vid`)
with the ORM lookup for better maintainability.
Add tests for the lookup and the FilterSet behavior.

Closes #20497
Martin Hauser 4 miesięcy temu
rodzic
commit
33d4759871

+ 32 - 1
netbox/extras/lookups.py

@@ -1,9 +1,39 @@
+from django.contrib.postgres.fields import ArrayField
+from django.contrib.postgres.fields.ranges import RangeField
 from django.db.models import CharField, JSONField, Lookup
 from django.db.models import CharField, JSONField, Lookup
 from django.db.models.fields.json import KeyTextTransform
 from django.db.models.fields.json import KeyTextTransform
 
 
 from .fields import CachedValueField
 from .fields import CachedValueField
 
 
 
 
+class RangeContains(Lookup):
+    """
+    Filter ArrayField(RangeField) columns where ANY element-range contains the scalar RHS.
+
+    Usage (ORM):
+        Model.objects.filter(<range_array_field>__range_contains=<scalar>)
+
+    Works with int4range[], int8range[], daterange[], tstzrange[], etc.
+    """
+
+    lookup_name = 'range_contains'
+
+    def as_sql(self, compiler, connection):
+        # Compile LHS (the array-of-ranges column/expression) and RHS (scalar)
+        lhs, lhs_params = self.process_lhs(compiler, connection)
+        rhs, rhs_params = self.process_rhs(compiler, connection)
+
+        # Guard: only allow ArrayField whose base_field is a PostgreSQL RangeField
+        field = getattr(self.lhs, 'output_field', None)
+        if not (isinstance(field, ArrayField) and isinstance(field.base_field, RangeField)):
+            raise TypeError('range_contains is only valid for ArrayField(RangeField) columns')
+
+        # Range-contains-element using EXISTS + UNNEST keeps the range on the LHS: r @> value
+        sql = f"EXISTS (SELECT 1 FROM unnest({lhs}) AS r WHERE r @> {rhs})"
+        params = lhs_params + rhs_params
+        return sql, params
+
+
 class Empty(Lookup):
 class Empty(Lookup):
     """
     """
     Filter on whether a string is empty.
     Filter on whether a string is empty.
@@ -25,7 +55,7 @@ class JSONEmpty(Lookup):
 
 
     A key is considered empty if it is "", null, or does not exist.
     A key is considered empty if it is "", null, or does not exist.
     """
     """
-    lookup_name = "empty"
+    lookup_name = 'empty'
 
 
     def as_sql(self, compiler, connection):
     def as_sql(self, compiler, connection):
         # self.lhs.lhs is the parent expression (could be a JSONField or another KeyTransform)
         # self.lhs.lhs is the parent expression (could be a JSONField or another KeyTransform)
@@ -69,6 +99,7 @@ class NetContainsOrEquals(Lookup):
         return 'CAST(%s AS INET) >>= %s' % (lhs, rhs), params
         return 'CAST(%s AS INET) >>= %s' % (lhs, rhs), params
 
 
 
 
+ArrayField.register_lookup(RangeContains)
 CharField.register_lookup(Empty)
 CharField.register_lookup(Empty)
 JSONField.register_lookup(JSONEmpty)
 JSONField.register_lookup(JSONEmpty)
 CachedValueField.register_lookup(NetHost)
 CachedValueField.register_lookup(NetHost)

+ 2 - 16
netbox/ipam/filtersets.py

@@ -908,7 +908,8 @@ class VLANGroupFilterSet(OrganizationalModelFilterSet, TenancyFilterSet):
         method='filter_scope'
         method='filter_scope'
     )
     )
     contains_vid = django_filters.NumberFilter(
     contains_vid = django_filters.NumberFilter(
-        method='filter_contains_vid'
+        field_name='vid_ranges',
+        lookup_expr='range_contains',
     )
     )
 
 
     class Meta:
     class Meta:
@@ -931,21 +932,6 @@ class VLANGroupFilterSet(OrganizationalModelFilterSet, TenancyFilterSet):
             scope_id=value
             scope_id=value
         )
         )
 
 
-    def filter_contains_vid(self, queryset, name, value):
-        """
-        Return all VLANGroups which contain the given VLAN ID.
-        """
-        table_name = VLANGroup._meta.db_table
-        # TODO: See if this can be optimized without compromising queryset integrity
-        # Expand VLAN ID ranges to query by integer
-        groups = VLANGroup.objects.raw(
-            f'SELECT id FROM {table_name}, unnest(vid_ranges) vid_range WHERE %s <@ vid_range',
-            params=(value,)
-        )
-        return queryset.filter(
-            pk__in=[g.id for g in groups]
-        )
-
 
 
 class VLANFilterSet(NetBoxModelFilterSet, TenancyFilterSet):
 class VLANFilterSet(NetBoxModelFilterSet, TenancyFilterSet):
     region_id = TreeNodeMultipleChoiceFilter(
     region_id = TreeNodeMultipleChoiceFilter(

+ 2 - 2
netbox/ipam/graphql/filters.py

@@ -19,7 +19,7 @@ from tenancy.graphql.filter_mixins import ContactFilterMixin, TenancyFilterMixin
 from virtualization.models import VMInterface
 from virtualization.models import VMInterface
 
 
 if TYPE_CHECKING:
 if TYPE_CHECKING:
-    from netbox.graphql.filter_lookups import IntegerArrayLookup, IntegerLookup
+    from netbox.graphql.filter_lookups import IntegerLookup, IntegerRangeArrayLookup
     from circuits.graphql.filters import ProviderFilter
     from circuits.graphql.filters import ProviderFilter
     from core.graphql.filters import ContentTypeFilter
     from core.graphql.filters import ContentTypeFilter
     from dcim.graphql.filters import SiteFilter
     from dcim.graphql.filters import SiteFilter
@@ -340,7 +340,7 @@ class VLANFilter(TenancyFilterMixin, PrimaryModelFilterMixin):
 
 
 @strawberry_django.filter_type(models.VLANGroup, lookups=True)
 @strawberry_django.filter_type(models.VLANGroup, lookups=True)
 class VLANGroupFilter(ScopedFilterMixin, OrganizationalModelFilterMixin):
 class VLANGroupFilter(ScopedFilterMixin, OrganizationalModelFilterMixin):
-    vid_ranges: Annotated['IntegerArrayLookup', strawberry.lazy('netbox.graphql.filter_lookups')] | None = (
+    vid_ranges: Annotated['IntegerRangeArrayLookup', strawberry.lazy('netbox.graphql.filter_lookups')] | None = (
         strawberry_django.filter_field()
         strawberry_django.filter_field()
     )
     )
 
 

+ 4 - 0
netbox/ipam/tests/test_filtersets.py

@@ -1723,6 +1723,10 @@ class VLANGroupTestCase(TestCase, ChangeLoggedFilterSetTests):
         self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
         self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
         params = {'contains_vid': 1}
         params = {'contains_vid': 1}
         self.assertEqual(self.filterset(params, self.queryset).qs.count(), 8)
         self.assertEqual(self.filterset(params, self.queryset).qs.count(), 8)
+        params = {'contains_vid': 12}  # 11 is NOT in [1,11)
+        self.assertEqual(self.filterset(params, self.queryset).qs.count(), 1)
+        params = {'contains_vid': 4095}
+        self.assertEqual(self.filterset(params, self.queryset).qs.count(), 0)
 
 
     def test_region(self):
     def test_region(self):
         params = {'region': Region.objects.first().pk}
         params = {'region': Region.objects.first().pk}

+ 66 - 0
netbox/ipam/tests/test_lookups.py

@@ -0,0 +1,66 @@
+from django.test import TestCase
+from django.db.backends.postgresql.psycopg_any import NumericRange
+from ipam.models import VLANGroup
+
+
+class VLANGroupRangeContainsLookupTests(TestCase):
+    @classmethod
+    def setUpTestData(cls):
+        # Two ranges: [1,11) and [20,31)
+        cls.g1 = VLANGroup.objects.create(
+            name='VlanGroup-A',
+            slug='VlanGroup-A',
+            vid_ranges=[NumericRange(1, 11), NumericRange(20, 31)],
+        )
+        # One range: [100,201)
+        cls.g2 = VLANGroup.objects.create(
+            name='VlanGroup-B',
+            slug='VlanGroup-B',
+            vid_ranges=[NumericRange(100, 201)],
+        )
+        cls.g_empty = VLANGroup.objects.create(
+            name='VlanGroup-empty',
+            slug='VlanGroup-empty',
+            vid_ranges=[],
+        )
+
+    def test_contains_value_in_first_range(self):
+        """
+        Tests whether a specific value is contained within the first range in a queried
+        set of VLANGroup objects.
+        """
+        names = list(
+            VLANGroup.objects.filter(vid_ranges__range_contains=10).values_list('name', flat=True).order_by('name')
+        )
+        self.assertEqual(names, ['VlanGroup-A'])
+
+    def test_contains_value_in_second_range(self):
+        """
+        Tests if a value exists in the second range of VLANGroup objects and
+        validates the result against the expected list of names.
+        """
+        names = list(
+            VLANGroup.objects.filter(vid_ranges__range_contains=25).values_list('name', flat=True).order_by('name')
+        )
+        self.assertEqual(names, ['VlanGroup-A'])
+
+    def test_upper_bound_is_exclusive(self):
+        """
+        Tests if the upper bound of the range is exclusive in the filter method.
+        """
+        # 11 is NOT in [1,11)
+        self.assertFalse(VLANGroup.objects.filter(vid_ranges__range_contains=11).exists())
+
+    def test_no_match_far_outside(self):
+        """
+        Tests that no VLANGroup contains a VID within a specified range far outside
+        common VID bounds and returns `False`.
+        """
+        self.assertFalse(VLANGroup.objects.filter(vid_ranges__range_contains=4095).exists())
+
+    def test_empty_array_never_matches(self):
+        """
+        Tests the behavior of VLANGroup objects when an empty array is used to match a
+        specific condition.
+        """
+        self.assertFalse(VLANGroup.objects.filter(pk=self.g_empty.pk, vid_ranges__range_contains=1).exists())

+ 28 - 0
netbox/netbox/graphql/filter_lookups.py

@@ -24,6 +24,7 @@ __all__ = (
     'FloatLookup',
     'FloatLookup',
     'IntegerArrayLookup',
     'IntegerArrayLookup',
     'IntegerLookup',
     'IntegerLookup',
+    'IntegerRangeArrayLookup',
     'JSONFilter',
     'JSONFilter',
     'StringArrayLookup',
     'StringArrayLookup',
     'TreeNodeFilter',
     'TreeNodeFilter',
@@ -217,3 +218,30 @@ class FloatArrayLookup(ArrayLookup[float]):
 @strawberry.input(one_of=True, description='Lookup for Array fields. Only one of the lookup fields can be set.')
 @strawberry.input(one_of=True, description='Lookup for Array fields. Only one of the lookup fields can be set.')
 class StringArrayLookup(ArrayLookup[str]):
 class StringArrayLookup(ArrayLookup[str]):
     pass
     pass
+
+
+@strawberry.input(one_of=True, description='Lookups for an ArrayField(RangeField). Only one may be set.')
+class RangeArrayValueLookup(Generic[T]):
+    """
+    class for Array field of Range fields lookups
+    """
+
+    contains: T | None = strawberry.field(
+        default=strawberry.UNSET, description='Return rows where any stored range contains this value.'
+    )
+
+    @strawberry_django.filter_field
+    def filter(self, info: Info, queryset: QuerySet, prefix: str = '') -> Tuple[QuerySet, Q]:
+        """
+        Map GraphQL: { <field>: { contains: <T> } } To Django ORM: <field>__range_contains=<T>
+        """
+        if self.contains is strawberry.UNSET or self.contains is None:
+            return queryset, Q()
+
+        # Build '<prefix>range_contains' so it works for nested paths too
+        return queryset, Q(**{f'{prefix}range_contains': self.contains})
+
+
+@strawberry.input(one_of=True, description='Lookups for an ArrayField(IntegerRangeField). Only one may be set.')
+class IntegerRangeArrayLookup(RangeArrayValueLookup[int]):
+    pass