Jelajahi Sumber

Closes #22375: Fix VLAN filter_interface_id performance: use UNION instead of OR across M2M joins (#22387)

Alex Houlton 1 bulan lalu
induk
melakukan
b905e99e63
2 mengubah file dengan 46 tambahan dan 9 penghapusan
  1. 6 8
      netbox/ipam/filtersets.py
  2. 40 1
      netbox/ipam/tests/test_filtersets.py

+ 6 - 8
netbox/ipam/filtersets.py

@@ -1142,19 +1142,17 @@ class VLANFilterSet(PrimaryModelFilterSet, TenancyFilterSet):
     def filter_interface_id(self, queryset, name, value):
         if value is None:
             return queryset.none()
-        return queryset.filter(
-            Q(interfaces_as_tagged=value) |
-            Q(interfaces_as_untagged=value)
-        ).distinct()
+        tagged = queryset.filter(interfaces_as_tagged=value)
+        untagged = queryset.filter(interfaces_as_untagged=value)
+        return queryset.filter(pk__in=tagged.union(untagged).values('pk'))
 
     @extend_schema_field(OpenApiTypes.INT)
     def filter_vminterface_id(self, queryset, name, value):
         if value is None:
             return queryset.none()
-        return queryset.filter(
-            Q(vminterfaces_as_tagged=value) |
-            Q(vminterfaces_as_untagged=value)
-        ).distinct()
+        tagged = queryset.filter(vminterfaces_as_tagged=value)
+        untagged = queryset.filter(vminterfaces_as_untagged=value)
+        return queryset.filter(pk__in=tagged.union(untagged).values('pk'))
 
 
 @register_filterset

+ 40 - 1
netbox/ipam/tests/test_filtersets.py

@@ -4,7 +4,7 @@ from django.test import TestCase
 from netaddr import IPNetwork
 
 from circuits.models import Provider
-from dcim.choices import InterfaceTypeChoices
+from dcim.choices import InterfaceModeChoices, InterfaceTypeChoices
 from dcim.models import Device, DeviceRole, DeviceType, Interface, Location, Manufacturer, Rack, Region, Site, SiteGroup
 from ipam.choices import *
 from ipam.filtersets import *
@@ -2206,11 +2206,50 @@ class VLANTestCase(TestCase, ChangeLoggedFilterSetTests):
         params = {'interface_id': interface_id}
         self.assertEqual(self.filterset(params, self.queryset).qs.count(), 1)
 
+        # An interface untagged on one VLAN and tagged on a different VLAN should return both (UNION across paths)
+        vlans = self.queryset.all()[:2]
+        interface = Interface.objects.create(
+            device=Device.objects.first(),
+            name='Interface X',
+            type=InterfaceTypeChoices.TYPE_1GE_FIXED,
+            mode=InterfaceModeChoices.MODE_TAGGED,
+            untagged_vlan=vlans[0],
+        )
+        interface.tagged_vlans.add(vlans[1])
+        params = {'interface_id': interface.pk}
+        self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
+
+        # A VLAN that is both untagged and tagged on the same interface should be returned only once (deduplication)
+        interface.tagged_vlans.add(vlans[0])
+        params = {'interface_id': interface.pk}
+        qs = self.filterset(params, self.queryset).qs
+        self.assertEqual(qs.count(), 2)
+        self.assertEqual(len(qs), len(set(qs.values_list('pk', flat=True))))
+
     def test_vminterface(self):
         vminterface_id = VMInterface.objects.first().pk
         params = {'vminterface_id': vminterface_id}
         self.assertEqual(self.filterset(params, self.queryset).qs.count(), 1)
 
+        # A VM interface untagged on one VLAN and tagged on a different VLAN should return both (UNION across paths)
+        vlans = self.queryset.all()[:2]
+        vminterface = VMInterface.objects.create(
+            virtual_machine=VirtualMachine.objects.first(),
+            name='VM Interface X',
+            mode=InterfaceModeChoices.MODE_TAGGED,
+            untagged_vlan=vlans[0],
+        )
+        vminterface.tagged_vlans.add(vlans[1])
+        params = {'vminterface_id': vminterface.pk}
+        self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
+
+        # A VLAN that is both untagged and tagged on the same interface should be returned only once (deduplication)
+        vminterface.tagged_vlans.add(vlans[0])
+        params = {'vminterface_id': vminterface.pk}
+        qs = self.filterset(params, self.queryset).qs
+        self.assertEqual(qs.count(), 2)
+        self.assertEqual(len(qs), len(set(qs.values_list('pk', flat=True))))
+
     def test_qinq_role(self):
         params = {'qinq_role': [VLANQinQRoleChoices.ROLE_SERVICE, VLANQinQRoleChoices.ROLE_CUSTOMER]}
         self.assertEqual(self.filterset(params, self.queryset).qs.count(), 6)