Explorar el Código

Closes #20241: Record A & B terminations on cable changelog records (#20246)

Jeremy Stretch hace 5 meses
padre
commit
873372f61e
Se han modificado 3 ficheros con 122 adiciones y 55 borrados
  1. 116 49
      netbox/dcim/models/cables.py
  2. 4 4
      netbox/utilities/testing/api.py
  3. 2 2
      netbox/utilities/testing/views.py

+ 116 - 49
netbox/dcim/models/cables.py

@@ -18,6 +18,7 @@ from utilities.conversion import to_meters
 from utilities.exceptions import AbortRequest
 from utilities.fields import ColorField, GenericArrayForeignKey
 from utilities.querysets import RestrictedQuerySet
+from utilities.serialization import deserialize_object, serialize_object
 from wireless.models import WirelessLink
 from .device_components import FrontPort, RearPort, PathEndpoint
 
@@ -119,43 +120,61 @@ class Cable(PrimaryModel):
         pk = self.pk or self._pk
         return self.label or f'#{pk}'
 
-    @property
-    def a_terminations(self):
-        if hasattr(self, '_a_terminations'):
-            return self._a_terminations
+    def get_status_color(self):
+        return LinkStatusChoices.colors.get(self.status)
+
+    def _get_x_terminations(self, side):
+        """
+        Return the terminating objects for the given cable end (A or B).
+        """
+        if side not in (CableEndChoices.SIDE_A, CableEndChoices.SIDE_B):
+            raise ValueError(f"Unknown cable side: {side}")
+        attr = f'_{side.lower()}_terminations'
 
+        if hasattr(self, attr):
+            return getattr(self, attr)
         if not self.pk:
             return []
-
-        # Query self.terminations.all() to leverage cached results
         return [
-            ct.termination for ct in self.terminations.all() if ct.cable_end == CableEndChoices.SIDE_A
+            # Query self.terminations.all() to leverage cached results
+            ct.termination for ct in self.terminations.all() if ct.cable_end == side
         ]
 
-    @a_terminations.setter
-    def a_terminations(self, value):
-        if not self.pk or self.a_terminations != list(value):
+    def _set_x_terminations(self, side, value):
+        """
+        Set the terminating objects for the given cable end (A or B).
+        """
+        if side not in (CableEndChoices.SIDE_A, CableEndChoices.SIDE_B):
+            raise ValueError(f"Unknown cable side: {side}")
+        _attr = f'_{side.lower()}_terminations'
+
+        # If the provided value is a list of CableTermination IDs, resolve them
+        # to their corresponding termination objects.
+        if all(isinstance(item, int) for item in value):
+            value = [
+                ct.termination for ct in CableTermination.objects.filter(pk__in=value).prefetch_related('termination')
+            ]
+
+        if not self.pk or getattr(self, _attr, []) != list(value):
             self._terminations_modified = True
-        self._a_terminations = value
+
+        setattr(self, _attr, value)
 
     @property
-    def b_terminations(self):
-        if hasattr(self, '_b_terminations'):
-            return self._b_terminations
+    def a_terminations(self):
+        return self._get_x_terminations(CableEndChoices.SIDE_A)
 
-        if not self.pk:
-            return []
+    @a_terminations.setter
+    def a_terminations(self, value):
+        self._set_x_terminations(CableEndChoices.SIDE_A, value)
 
-        # Query self.terminations.all() to leverage cached results
-        return [
-            ct.termination for ct in self.terminations.all() if ct.cable_end == CableEndChoices.SIDE_B
-        ]
+    @property
+    def b_terminations(self):
+        return self._get_x_terminations(CableEndChoices.SIDE_B)
 
     @b_terminations.setter
     def b_terminations(self, value):
-        if not self.pk or self.b_terminations != list(value):
-            self._terminations_modified = True
-        self._b_terminations = value
+        self._set_x_terminations(CableEndChoices.SIDE_B, value)
 
     @property
     def color_name(self):
@@ -208,7 +227,7 @@ class Cable(PrimaryModel):
             for termination in self.b_terminations:
                 CableTermination(cable=self, cable_end='B', termination=termination).clean()
 
-    def save(self, *args, **kwargs):
+    def save(self, *args, force_insert=False, force_update=False, using=None, update_fields=None):
         _created = self.pk is None
 
         # Store the given length (if any) in meters for use in database ordering
@@ -221,39 +240,87 @@ class Cable(PrimaryModel):
         if self.length is None:
             self.length_unit = None
 
-        super().save(*args, **kwargs)
+        # If this is a new Cable, save it before attempting to create its CableTerminations
+        if self._state.adding:
+            super().save(*args, force_insert=True, using=using, update_fields=update_fields)
+            # Update the private PK used in __str__()
+            self._pk = self.pk
 
-        # Update the private pk used in __str__ in case this is a new object (i.e. just got its pk)
-        self._pk = self.pk
+        if self._terminations_modified:
+            self.update_terminations()
 
-        # Retrieve existing A/B terminations for the Cable
-        a_terminations = {ct.termination: ct for ct in self.terminations.filter(cable_end='A')}
-        b_terminations = {ct.termination: ct for ct in self.terminations.filter(cable_end='B')}
+        super().save(*args, force_update=True, using=using, update_fields=update_fields)
 
-        # Delete stale CableTerminations
-        if self._terminations_modified:
-            for termination, ct in a_terminations.items():
-                if termination.pk and termination not in self.a_terminations:
-                    ct.delete()
-            for termination, ct in b_terminations.items():
-                if termination.pk and termination not in self.b_terminations:
-                    ct.delete()
-
-        # Save new CableTerminations (if any)
-        if self._terminations_modified:
-            for termination in self.a_terminations:
-                if not termination.pk or termination not in a_terminations:
-                    CableTermination(cable=self, cable_end='A', termination=termination).save()
-            for termination in self.b_terminations:
-                if not termination.pk or termination not in b_terminations:
-                    CableTermination(cable=self, cable_end='B', termination=termination).save()
         try:
             trace_paths.send(Cable, instance=self, created=_created)
         except UnsupportedCablePath as e:
             raise AbortRequest(e)
 
-    def get_status_color(self):
-        return LinkStatusChoices.colors.get(self.status)
+    def serialize_object(self, exclude=None):
+        data = serialize_object(self, exclude=exclude or [])
+
+        # Add A & B terminations to the serialized data
+        a_terminations, b_terminations = self.get_terminations()
+        data['a_terminations'] = sorted([ct.pk for ct in a_terminations.values()])
+        data['b_terminations'] = sorted([ct.pk for ct in b_terminations.values()])
+
+        return data
+
+    @classmethod
+    def deserialize_object(cls, data, pk=None):
+        a_terminations = data.pop('a_terminations', [])
+        b_terminations = data.pop('b_terminations', [])
+
+        instance = deserialize_object(cls, data, pk=pk)
+
+        # Assign A & B termination objects to the Cable instance
+        queryset = CableTermination.objects.prefetch_related('termination')
+        instance.a_terminations = [
+            ct.termination for ct in queryset.filter(pk__in=a_terminations)
+        ]
+        instance.b_terminations = [
+            ct.termination for ct in queryset.filter(pk__in=b_terminations)
+        ]
+
+        return instance
+
+    def get_terminations(self):
+        """
+        Return two dictionaries mapping A & B side terminating objects to their corresponding CableTerminations
+        for this Cable.
+        """
+        a_terminations = {}
+        b_terminations = {}
+
+        for ct in CableTermination.objects.filter(cable=self).prefetch_related('termination'):
+            if ct.cable_end == CableEndChoices.SIDE_A:
+                a_terminations[ct.termination] = ct
+            else:
+                b_terminations[ct.termination] = ct
+
+        return a_terminations, b_terminations
+
+    def update_terminations(self):
+        """
+        Create/delete CableTerminations for this Cable to reflect its current state.
+        """
+        a_terminations, b_terminations = self.get_terminations()
+
+        # Delete any stale CableTerminations
+        for termination, ct in a_terminations.items():
+            if termination.pk and termination not in self.a_terminations:
+                ct.delete()
+        for termination, ct in b_terminations.items():
+            if termination.pk and termination not in self.b_terminations:
+                ct.delete()
+
+        # Save any new CableTerminations
+        for termination in self.a_terminations:
+            if not termination.pk or termination not in a_terminations:
+                CableTermination(cable=self, cable_end='A', termination=termination).save()
+        for termination in self.b_terminations:
+            if not termination.pk or termination not in b_terminations:
+                CableTermination(cable=self, cable_end='B', termination=termination).save()
 
 
 class CableTermination(ChangeLoggedModel):

+ 4 - 4
netbox/utilities/testing/api.py

@@ -247,9 +247,9 @@ class APIViewTestCases:
             if issubclass(self.model, ChangeLoggingMixin):
                 objectchange = ObjectChange.objects.get(
                     changed_object_type=ContentType.objects.get_for_model(instance),
-                    changed_object_id=instance.pk
+                    changed_object_id=instance.pk,
+                    action=ObjectChangeActionChoices.ACTION_CREATE,
                 )
-                self.assertEqual(objectchange.action, ObjectChangeActionChoices.ACTION_CREATE)
                 self.assertEqual(objectchange.message, data['changelog_message'])
 
         def test_bulk_create_objects(self):
@@ -298,11 +298,11 @@ class APIViewTestCases:
                 ]
                 objectchanges = ObjectChange.objects.filter(
                     changed_object_type=ContentType.objects.get_for_model(self.model),
-                    changed_object_id__in=id_list
+                    changed_object_id__in=id_list,
+                    action=ObjectChangeActionChoices.ACTION_CREATE,
                 )
                 self.assertEqual(len(objectchanges), len(self.create_data))
                 for oc in objectchanges:
-                    self.assertEqual(oc.action, ObjectChangeActionChoices.ACTION_CREATE)
                     self.assertEqual(oc.message, changelog_message)
 
     class UpdateObjectViewTestCase(APITestCase):

+ 2 - 2
netbox/utilities/testing/views.py

@@ -655,11 +655,11 @@ class ViewTestCases:
                 self.assertIsNotNone(request_id, "Unable to determine request ID from response")
                 objectchanges = ObjectChange.objects.filter(
                     changed_object_type=ContentType.objects.get_for_model(self.model),
-                    request_id=request_id
+                    request_id=request_id,
+                    action=ObjectChangeActionChoices.ACTION_CREATE,
                 )
                 self.assertEqual(len(objectchanges), len(self.csv_data) - 1)
                 for oc in objectchanges:
-                    self.assertEqual(oc.action, ObjectChangeActionChoices.ACTION_CREATE)
                     self.assertEqual(oc.message, data['changelog_message'])
 
         @override_settings(EXEMPT_VIEW_PERMISSIONS=['*'])