Explorar el Código

Closes #21263: Prefetch related objects after creating/updating objects via REST API (#21329)

* Closes #21263: Prefetch related objects after creating/updating objects via REST API

* Add comment re: ordering by PK
Jeremy Stretch hace 2 semanas
padre
commit
ad29cb2d66
Se han modificado 2 ficheros con 45 adiciones y 8 borrados
  1. 36 3
      netbox/netbox/api/viewsets/__init__.py
  2. 9 5
      netbox/netbox/api/viewsets/mixins.py

+ 36 - 3
netbox/netbox/api/viewsets/__init__.py

@@ -170,6 +170,28 @@ class NetBoxModelViewSet(
 
     # Creates
 
+    def create(self, request, *args, **kwargs):
+        serializer = self.get_serializer(data=request.data)
+        serializer.is_valid(raise_exception=True)
+        bulk_create = getattr(serializer, 'many', False)
+        self.perform_create(serializer)
+
+        # After creating the instance(s), re-initialize the serializer with a queryset
+        # to ensure related objects are prefetched.
+        if bulk_create:
+            instance_pks = [obj.pk for obj in serializer.instance]
+            # Order by PK to ensure that the ordering of objects in the response
+            # matches the ordering of those in the request.
+            qs = self.get_queryset().filter(pk__in=instance_pks).order_by('pk')
+        else:
+            qs = self.get_queryset().get(pk=serializer.instance.pk)
+
+        # Re-serialize the instance(s) with prefetched data
+        serializer = self.get_serializer(qs, many=bulk_create)
+
+        headers = self.get_success_headers(serializer.data)
+        return Response(serializer.data, status=status.HTTP_201_CREATED, headers=headers)
+
     def perform_create(self, serializer):
         model = self.queryset.model
         logger = logging.getLogger(f'netbox.api.views.{self.__class__.__name__}')
@@ -186,9 +208,20 @@ class NetBoxModelViewSet(
     # Updates
 
     def update(self, request, *args, **kwargs):
-        # Hotwire get_object() to ensure we save a pre-change snapshot
-        self.get_object = self.get_object_with_snapshot
-        return super().update(request, *args, **kwargs)
+        partial = kwargs.pop('partial', False)
+        instance = self.get_object_with_snapshot()
+        serializer = self.get_serializer(instance, data=request.data, partial=partial)
+        serializer.is_valid(raise_exception=True)
+        self.perform_update(serializer)
+
+        # After updating the instance, re-initialize the serializer with a queryset
+        # to ensure related objects are prefetched.
+        qs = self.get_queryset().get(pk=serializer.instance.pk)
+
+        # Re-serialize the instance(s) with prefetched data
+        serializer = self.get_serializer(qs)
+
+        return Response(serializer.data)
 
     def perform_update(self, serializer):
         model = self.queryset.model

+ 9 - 5
netbox/netbox/api/viewsets/mixins.py

@@ -108,13 +108,17 @@ class BulkUpdateModelMixin:
             obj.pop('id'): obj for obj in request.data
         }
 
-        data = self.perform_bulk_update(qs, update_data, partial=partial)
+        object_pks = self.perform_bulk_update(qs, update_data, partial=partial)
 
-        return Response(data, status=status.HTTP_200_OK)
+        # Prefetch related objects for all updated instances
+        qs = self.get_queryset().filter(pk__in=object_pks)
+        serializer = self.get_serializer(qs, many=True)
+
+        return Response(serializer.data, status=status.HTTP_200_OK)
 
     def perform_bulk_update(self, objects, update_data, partial):
+        updated_pks = []
         with transaction.atomic(using=router.db_for_write(self.queryset.model)):
-            data_list = []
             for obj in objects:
                 data = update_data.get(obj.id)
                 if hasattr(obj, 'snapshot'):
@@ -122,9 +126,9 @@ class BulkUpdateModelMixin:
                 serializer = self.get_serializer(obj, data=data, partial=partial)
                 serializer.is_valid(raise_exception=True)
                 self.perform_update(serializer)
-                data_list.append(serializer.data)
+                updated_pks.append(obj.pk)
 
-            return data_list
+        return updated_pks
 
     def bulk_partial_update(self, request, *args, **kwargs):
         kwargs['partial'] = True