Bläddra i källkod

Fixes #22233: Fix site & location filtering for cables connecting circuit terminations

Jeremy Stretch 1 dag sedan
förälder
incheckning
d15d15b285

+ 35 - 0
netbox/dcim/migrations/0235_cabletermination_circuit_site_cache.py

@@ -0,0 +1,35 @@
+from django.db import migrations
+from django.db.models import OuterRef, Subquery
+
+
+def populate_circuit_termination_site_cache(apps, schema_editor):
+    """
+    Populate the cached _site and _location fields on CableTermination records whose
+    termination is a CircuitTermination. Earlier versions failed to cache these values,
+    causing site/location filters on cables to omit such terminations.
+    """
+    ContentType = apps.get_model('contenttypes', 'ContentType')
+    CableTermination = apps.get_model('dcim', 'CableTermination')
+    CircuitTermination = apps.get_model('circuits', 'CircuitTermination')
+
+    try:
+        ct = ContentType.objects.get_by_natural_key('circuits', 'circuittermination')
+    except ContentType.DoesNotExist:
+        return
+
+    circuit_terminations = CircuitTermination.objects.filter(pk=OuterRef('termination_id'))
+    CableTermination.objects.filter(termination_type=ct).update(
+        _site=Subquery(circuit_terminations.values('_site_id')[:1]),
+        _location=Subquery(circuit_terminations.values('_location_id')[:1]),
+    )
+
+
+class Migration(migrations.Migration):
+    dependencies = [
+        ('dcim', '0234_cablepath_nodes_index'),
+        ('circuits', '0057_default_ordering_indexes'),
+    ]
+
+    operations = [
+        migrations.RunPython(populate_circuit_termination_site_cache, migrations.RunPython.noop),
+    ]

+ 4 - 3
netbox/dcim/models/cables.py

@@ -690,9 +690,10 @@ class CableTermination(ChangeLoggedModel):
             self._location = self.termination.rack.location
             self._location = self.termination.rack.location
             self._site = self.termination.rack.site
             self._site = self.termination.rack.site
 
 
-        # Circuit terminations
-        elif getattr(self.termination, 'site', None):
-            self._site = self.termination.site
+        # Circuit terminations (which cache their own site/location)
+        elif self.termination._meta.label_lower == 'circuits.circuittermination':
+            self._site = self.termination._site
+            self._location = self.termination._location
     cache_related_objects.alters_data = True
     cache_related_objects.alters_data = True
 
 
     def to_objectchange(self, action):
     def to_objectchange(self, action):

+ 19 - 7
netbox/dcim/tests/test_filtersets.py

@@ -6814,6 +6814,13 @@ class CableTestCase(TestCase, ChangeLoggedFilterSetTests):
         circuit_type = CircuitType.objects.create(name='Circuit Type 1', slug='circuit-type-1')
         circuit_type = CircuitType.objects.create(name='Circuit Type 1', slug='circuit-type-1')
         circuit = Circuit.objects.create(cid='Circuit 1', provider=provider, type=circuit_type)
         circuit = Circuit.objects.create(cid='Circuit 1', provider=provider, type=circuit_type)
         circuit_termination = CircuitTermination.objects.create(circuit=circuit, term_side='A', termination=sites[0])
         circuit_termination = CircuitTermination.objects.create(circuit=circuit, term_side='A', termination=sites[0])
+        circuit2 = Circuit.objects.create(cid='Circuit 2', provider=provider, type=circuit_type)
+        circuit_termination_a = CircuitTermination.objects.create(
+            circuit=circuit2, term_side='A', termination=sites[0]
+        )
+        circuit_termination_z = CircuitTermination.objects.create(
+            circuit=circuit2, term_side='Z', termination=locations[0]
+        )
 
 
         # Cables
         # Cables
         cables = (
         cables = (
@@ -6920,6 +6927,11 @@ class CableTestCase(TestCase, ChangeLoggedFilterSetTests):
                 a_terminations=[circuit_termination],
                 a_terminations=[circuit_termination],
                 label='Cable 14'
                 label='Cable 14'
             ),
             ),
+            Cable(
+                a_terminations=[circuit_termination_a],
+                b_terminations=[circuit_termination_z],
+                label='Cable 15'
+            ),
         )
         )
         for cable in cables:
         for cable in cables:
             cable.save()
             cable.save()
@@ -6944,13 +6956,13 @@ class CableTestCase(TestCase, ChangeLoggedFilterSetTests):
         params = {'type': [CableTypeChoices.TYPE_CAT3, CableTypeChoices.TYPE_CAT5E]}
         params = {'type': [CableTypeChoices.TYPE_CAT3, CableTypeChoices.TYPE_CAT5E]}
         self.assertEqual(self.filterset(params, self.queryset).qs.count(), 4)
         self.assertEqual(self.filterset(params, self.queryset).qs.count(), 4)
         params = {'type__empty': 'true'}
         params = {'type__empty': 'true'}
-        self.assertEqual(self.filterset(params, self.queryset).qs.count(), 8)
+        self.assertEqual(self.filterset(params, self.queryset).qs.count(), 9)
         params = {'type__empty': 'false'}
         params = {'type__empty': 'false'}
         self.assertEqual(self.filterset(params, self.queryset).qs.count(), 6)
         self.assertEqual(self.filterset(params, self.queryset).qs.count(), 6)
 
 
     def test_status(self):
     def test_status(self):
         params = {'status': [LinkStatusChoices.STATUS_CONNECTED]}
         params = {'status': [LinkStatusChoices.STATUS_CONNECTED]}
-        self.assertEqual(self.filterset(params, self.queryset).qs.count(), 11)
+        self.assertEqual(self.filterset(params, self.queryset).qs.count(), 12)
         params = {'status': [LinkStatusChoices.STATUS_PLANNED]}
         params = {'status': [LinkStatusChoices.STATUS_PLANNED]}
         self.assertEqual(self.filterset(params, self.queryset).qs.count(), 3)
         self.assertEqual(self.filterset(params, self.queryset).qs.count(), 3)
 
 
@@ -6979,16 +6991,16 @@ class CableTestCase(TestCase, ChangeLoggedFilterSetTests):
     def test_location(self):
     def test_location(self):
         locations = Location.objects.all()[:2]
         locations = Location.objects.all()[:2]
         params = {'location_id': [locations[0].pk, locations[1].pk]}
         params = {'location_id': [locations[0].pk, locations[1].pk]}
-        self.assertEqual(self.filterset(params, self.queryset).qs.count(), 11)
+        self.assertEqual(self.filterset(params, self.queryset).qs.count(), 12)
         params = {'location': [locations[0].name, locations[1].name]}
         params = {'location': [locations[0].name, locations[1].name]}
-        self.assertEqual(self.filterset(params, self.queryset).qs.count(), 11)
+        self.assertEqual(self.filterset(params, self.queryset).qs.count(), 12)
 
 
     def test_site(self):
     def test_site(self):
         site = Site.objects.all()[:2]
         site = Site.objects.all()[:2]
         params = {'site_id': [site[0].pk, site[1].pk]}
         params = {'site_id': [site[0].pk, site[1].pk]}
-        self.assertEqual(self.filterset(params, self.queryset).qs.count(), 11)
+        self.assertEqual(self.filterset(params, self.queryset).qs.count(), 13)
         params = {'site': [site[0].slug, site[1].slug]}
         params = {'site': [site[0].slug, site[1].slug]}
-        self.assertEqual(self.filterset(params, self.queryset).qs.count(), 11)
+        self.assertEqual(self.filterset(params, self.queryset).qs.count(), 13)
 
 
     def test_tenant(self):
     def test_tenant(self):
         tenant = Tenant.objects.all()[:2]
         tenant = Tenant.objects.all()[:2]
@@ -7016,7 +7028,7 @@ class CableTestCase(TestCase, ChangeLoggedFilterSetTests):
         params = {'unterminated': True}
         params = {'unterminated': True}
         self.assertEqual(self.filterset(params, self.queryset).qs.count(), 8)
         self.assertEqual(self.filterset(params, self.queryset).qs.count(), 8)
         params = {'unterminated': False}
         params = {'unterminated': False}
-        self.assertEqual(self.filterset(params, self.queryset).qs.count(), 6)
+        self.assertEqual(self.filterset(params, self.queryset).qs.count(), 7)
 
 
     def test_consoleport(self):
     def test_consoleport(self):
         params = {'consoleport_id': [ConsolePort.objects.first().pk]}
         params = {'consoleport_id': [ConsolePort.objects.first().pk]}