ソースを参照

feat(extras): Add range_contains ORM lookup

Introduce a generic lookup for ArrayField(RangeField) that matches rows
where a scalar value is contained by any range in the array
(e.g. VLANGroup.vid_ranges).
Replace the raw-SQL helper in the VLANGroup FilterSet (`contains_vid`)
with the ORM lookup for better maintainability.
Add tests for the lookup and the FilterSet behavior.

Closes #20497
Martin Hauser 4 ヶ月 前
コミット
33d4759871

+ 32 - 1
netbox/extras/lookups.py

@@ -1,9 +1,39 @@
+from django.contrib.postgres.fields import ArrayField
+from django.contrib.postgres.fields.ranges import RangeField
 from django.db.models import CharField, JSONField, Lookup
 from django.db.models.fields.json import KeyTextTransform
 
 from .fields import CachedValueField
 
 
+class RangeContains(Lookup):
+    """
+    Filter ArrayField(RangeField) columns where ANY element-range contains the scalar RHS.
+
+    Usage (ORM):
+        Model.objects.filter(<range_array_field>__range_contains=<scalar>)
+
+    Works with int4range[], int8range[], daterange[], tstzrange[], etc.
+    """
+
+    lookup_name = 'range_contains'
+
+    def as_sql(self, compiler, connection):
+        # Compile LHS (the array-of-ranges column/expression) and RHS (scalar)
+        lhs, lhs_params = self.process_lhs(compiler, connection)
+        rhs, rhs_params = self.process_rhs(compiler, connection)
+
+        # Guard: only allow ArrayField whose base_field is a PostgreSQL RangeField
+        field = getattr(self.lhs, 'output_field', None)
+        if not (isinstance(field, ArrayField) and isinstance(field.base_field, RangeField)):
+            raise TypeError('range_contains is only valid for ArrayField(RangeField) columns')
+
+        # Range-contains-element using EXISTS + UNNEST keeps the range on the LHS: r @> value
+        sql = f"EXISTS (SELECT 1 FROM unnest({lhs}) AS r WHERE r @> {rhs})"
+        params = lhs_params + rhs_params
+        return sql, params
+
+
 class Empty(Lookup):
     """
     Filter on whether a string is empty.
@@ -25,7 +55,7 @@ class JSONEmpty(Lookup):
 
     A key is considered empty if it is "", null, or does not exist.
     """
-    lookup_name = "empty"
+    lookup_name = 'empty'
 
     def as_sql(self, compiler, connection):
         # self.lhs.lhs is the parent expression (could be a JSONField or another KeyTransform)
@@ -69,6 +99,7 @@ class NetContainsOrEquals(Lookup):
         return 'CAST(%s AS INET) >>= %s' % (lhs, rhs), params
 
 
+ArrayField.register_lookup(RangeContains)
 CharField.register_lookup(Empty)
 JSONField.register_lookup(JSONEmpty)
 CachedValueField.register_lookup(NetHost)

+ 2 - 16
netbox/ipam/filtersets.py

@@ -908,7 +908,8 @@ class VLANGroupFilterSet(OrganizationalModelFilterSet, TenancyFilterSet):
         method='filter_scope'
     )
     contains_vid = django_filters.NumberFilter(
-        method='filter_contains_vid'
+        field_name='vid_ranges',
+        lookup_expr='range_contains',
     )
 
     class Meta:
@@ -931,21 +932,6 @@ class VLANGroupFilterSet(OrganizationalModelFilterSet, TenancyFilterSet):
             scope_id=value
         )
 
-    def filter_contains_vid(self, queryset, name, value):
-        """
-        Return all VLANGroups which contain the given VLAN ID.
-        """
-        table_name = VLANGroup._meta.db_table
-        # TODO: See if this can be optimized without compromising queryset integrity
-        # Expand VLAN ID ranges to query by integer
-        groups = VLANGroup.objects.raw(
-            f'SELECT id FROM {table_name}, unnest(vid_ranges) vid_range WHERE %s <@ vid_range',
-            params=(value,)
-        )
-        return queryset.filter(
-            pk__in=[g.id for g in groups]
-        )
-
 
 class VLANFilterSet(NetBoxModelFilterSet, TenancyFilterSet):
     region_id = TreeNodeMultipleChoiceFilter(

+ 2 - 2
netbox/ipam/graphql/filters.py

@@ -19,7 +19,7 @@ from tenancy.graphql.filter_mixins import ContactFilterMixin, TenancyFilterMixin
 from virtualization.models import VMInterface
 
 if TYPE_CHECKING:
-    from netbox.graphql.filter_lookups import IntegerArrayLookup, IntegerLookup
+    from netbox.graphql.filter_lookups import IntegerLookup, IntegerRangeArrayLookup
     from circuits.graphql.filters import ProviderFilter
     from core.graphql.filters import ContentTypeFilter
     from dcim.graphql.filters import SiteFilter
@@ -340,7 +340,7 @@ class VLANFilter(TenancyFilterMixin, PrimaryModelFilterMixin):
 
 @strawberry_django.filter_type(models.VLANGroup, lookups=True)
 class VLANGroupFilter(ScopedFilterMixin, OrganizationalModelFilterMixin):
-    vid_ranges: Annotated['IntegerArrayLookup', strawberry.lazy('netbox.graphql.filter_lookups')] | None = (
+    vid_ranges: Annotated['IntegerRangeArrayLookup', strawberry.lazy('netbox.graphql.filter_lookups')] | None = (
         strawberry_django.filter_field()
     )
 

+ 4 - 0
netbox/ipam/tests/test_filtersets.py

@@ -1723,6 +1723,10 @@ class VLANGroupTestCase(TestCase, ChangeLoggedFilterSetTests):
         self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
         params = {'contains_vid': 1}
         self.assertEqual(self.filterset(params, self.queryset).qs.count(), 8)
+        params = {'contains_vid': 12}  # 11 is NOT in [1,11)
+        self.assertEqual(self.filterset(params, self.queryset).qs.count(), 1)
+        params = {'contains_vid': 4095}
+        self.assertEqual(self.filterset(params, self.queryset).qs.count(), 0)
 
     def test_region(self):
         params = {'region': Region.objects.first().pk}

+ 66 - 0
netbox/ipam/tests/test_lookups.py

@@ -0,0 +1,66 @@
+from django.test import TestCase
+from django.db.backends.postgresql.psycopg_any import NumericRange
+from ipam.models import VLANGroup
+
+
+class VLANGroupRangeContainsLookupTests(TestCase):
+    @classmethod
+    def setUpTestData(cls):
+        # Two ranges: [1,11) and [20,31)
+        cls.g1 = VLANGroup.objects.create(
+            name='VlanGroup-A',
+            slug='VlanGroup-A',
+            vid_ranges=[NumericRange(1, 11), NumericRange(20, 31)],
+        )
+        # One range: [100,201)
+        cls.g2 = VLANGroup.objects.create(
+            name='VlanGroup-B',
+            slug='VlanGroup-B',
+            vid_ranges=[NumericRange(100, 201)],
+        )
+        cls.g_empty = VLANGroup.objects.create(
+            name='VlanGroup-empty',
+            slug='VlanGroup-empty',
+            vid_ranges=[],
+        )
+
+    def test_contains_value_in_first_range(self):
+        """
+        Tests whether a specific value is contained within the first range in a queried
+        set of VLANGroup objects.
+        """
+        names = list(
+            VLANGroup.objects.filter(vid_ranges__range_contains=10).values_list('name', flat=True).order_by('name')
+        )
+        self.assertEqual(names, ['VlanGroup-A'])
+
+    def test_contains_value_in_second_range(self):
+        """
+        Tests if a value exists in the second range of VLANGroup objects and
+        validates the result against the expected list of names.
+        """
+        names = list(
+            VLANGroup.objects.filter(vid_ranges__range_contains=25).values_list('name', flat=True).order_by('name')
+        )
+        self.assertEqual(names, ['VlanGroup-A'])
+
+    def test_upper_bound_is_exclusive(self):
+        """
+        Tests if the upper bound of the range is exclusive in the filter method.
+        """
+        # 11 is NOT in [1,11)
+        self.assertFalse(VLANGroup.objects.filter(vid_ranges__range_contains=11).exists())
+
+    def test_no_match_far_outside(self):
+        """
+        Tests that no VLANGroup contains a VID within a specified range far outside
+        common VID bounds and returns `False`.
+        """
+        self.assertFalse(VLANGroup.objects.filter(vid_ranges__range_contains=4095).exists())
+
+    def test_empty_array_never_matches(self):
+        """
+        Tests the behavior of VLANGroup objects when an empty array is used to match a
+        specific condition.
+        """
+        self.assertFalse(VLANGroup.objects.filter(pk=self.g_empty.pk, vid_ranges__range_contains=1).exists())

+ 28 - 0
netbox/netbox/graphql/filter_lookups.py

@@ -24,6 +24,7 @@ __all__ = (
     'FloatLookup',
     'IntegerArrayLookup',
     'IntegerLookup',
+    'IntegerRangeArrayLookup',
     'JSONFilter',
     'StringArrayLookup',
     'TreeNodeFilter',
@@ -217,3 +218,30 @@ class FloatArrayLookup(ArrayLookup[float]):
 @strawberry.input(one_of=True, description='Lookup for Array fields. Only one of the lookup fields can be set.')
 class StringArrayLookup(ArrayLookup[str]):
     pass
+
+
+@strawberry.input(one_of=True, description='Lookups for an ArrayField(RangeField). Only one may be set.')
+class RangeArrayValueLookup(Generic[T]):
+    """
+    class for Array field of Range fields lookups
+    """
+
+    contains: T | None = strawberry.field(
+        default=strawberry.UNSET, description='Return rows where any stored range contains this value.'
+    )
+
+    @strawberry_django.filter_field
+    def filter(self, info: Info, queryset: QuerySet, prefix: str = '') -> Tuple[QuerySet, Q]:
+        """
+        Map GraphQL: { <field>: { contains: <T> } } To Django ORM: <field>__range_contains=<T>
+        """
+        if self.contains is strawberry.UNSET or self.contains is None:
+            return queryset, Q()
+
+        # Build '<prefix>range_contains' so it works for nested paths too
+        return queryset, Q(**{f'{prefix}range_contains': self.contains})
+
+
+@strawberry.input(one_of=True, description='Lookups for an ArrayField(IntegerRangeField). Only one may be set.')
+class IntegerRangeArrayLookup(RangeArrayValueLookup[int]):
+    pass