Просмотр исходного кода

Closes #5549: Eliminate extraneous database queries when using brief API calls

Jeremy Stretch 5 лет назад
Родитель
Сommit
8ae3331d04

+ 1 - 0
netbox/circuits/api/views.py

@@ -65,3 +65,4 @@ class CircuitTerminationViewSet(PathEndpointMixin, ModelViewSet):
     )
     )
     serializer_class = serializers.CircuitTerminationSerializer
     serializer_class = serializers.CircuitTerminationSerializer
     filterset_class = filters.CircuitTerminationFilterSet
     filterset_class = filters.CircuitTerminationFilterSet
+    brief_prefetch_fields = ['circuit']

+ 11 - 0
netbox/dcim/api/views.py

@@ -258,6 +258,7 @@ class DeviceTypeViewSet(CustomFieldModelViewSet):
     )
     )
     serializer_class = serializers.DeviceTypeSerializer
     serializer_class = serializers.DeviceTypeSerializer
     filterset_class = filters.DeviceTypeFilterSet
     filterset_class = filters.DeviceTypeFilterSet
+    brief_prefetch_fields = ['manufacturer']
 
 
 
 
 #
 #
@@ -493,6 +494,7 @@ class ConsolePortViewSet(PathEndpointMixin, ModelViewSet):
     queryset = ConsolePort.objects.prefetch_related('device', '_path__destination', 'cable', '_cable_peer', 'tags')
     queryset = ConsolePort.objects.prefetch_related('device', '_path__destination', 'cable', '_cable_peer', 'tags')
     serializer_class = serializers.ConsolePortSerializer
     serializer_class = serializers.ConsolePortSerializer
     filterset_class = filters.ConsolePortFilterSet
     filterset_class = filters.ConsolePortFilterSet
+    brief_prefetch_fields = ['device']
 
 
 
 
 class ConsoleServerPortViewSet(PathEndpointMixin, ModelViewSet):
 class ConsoleServerPortViewSet(PathEndpointMixin, ModelViewSet):
@@ -501,18 +503,21 @@ class ConsoleServerPortViewSet(PathEndpointMixin, ModelViewSet):
     )
     )
     serializer_class = serializers.ConsoleServerPortSerializer
     serializer_class = serializers.ConsoleServerPortSerializer
     filterset_class = filters.ConsoleServerPortFilterSet
     filterset_class = filters.ConsoleServerPortFilterSet
+    brief_prefetch_fields = ['device']
 
 
 
 
 class PowerPortViewSet(PathEndpointMixin, ModelViewSet):
 class PowerPortViewSet(PathEndpointMixin, ModelViewSet):
     queryset = PowerPort.objects.prefetch_related('device', '_path__destination', 'cable', '_cable_peer', 'tags')
     queryset = PowerPort.objects.prefetch_related('device', '_path__destination', 'cable', '_cable_peer', 'tags')
     serializer_class = serializers.PowerPortSerializer
     serializer_class = serializers.PowerPortSerializer
     filterset_class = filters.PowerPortFilterSet
     filterset_class = filters.PowerPortFilterSet
+    brief_prefetch_fields = ['device']
 
 
 
 
 class PowerOutletViewSet(PathEndpointMixin, ModelViewSet):
 class PowerOutletViewSet(PathEndpointMixin, ModelViewSet):
     queryset = PowerOutlet.objects.prefetch_related('device', '_path__destination', 'cable', '_cable_peer', 'tags')
     queryset = PowerOutlet.objects.prefetch_related('device', '_path__destination', 'cable', '_cable_peer', 'tags')
     serializer_class = serializers.PowerOutletSerializer
     serializer_class = serializers.PowerOutletSerializer
     filterset_class = filters.PowerOutletFilterSet
     filterset_class = filters.PowerOutletFilterSet
+    brief_prefetch_fields = ['device']
 
 
 
 
 class InterfaceViewSet(PathEndpointMixin, ModelViewSet):
 class InterfaceViewSet(PathEndpointMixin, ModelViewSet):
@@ -521,30 +526,35 @@ class InterfaceViewSet(PathEndpointMixin, ModelViewSet):
     )
     )
     serializer_class = serializers.InterfaceSerializer
     serializer_class = serializers.InterfaceSerializer
     filterset_class = filters.InterfaceFilterSet
     filterset_class = filters.InterfaceFilterSet
+    brief_prefetch_fields = ['device']
 
 
 
 
 class FrontPortViewSet(PassThroughPortMixin, ModelViewSet):
 class FrontPortViewSet(PassThroughPortMixin, ModelViewSet):
     queryset = FrontPort.objects.prefetch_related('device__device_type__manufacturer', 'rear_port', 'cable', 'tags')
     queryset = FrontPort.objects.prefetch_related('device__device_type__manufacturer', 'rear_port', 'cable', 'tags')
     serializer_class = serializers.FrontPortSerializer
     serializer_class = serializers.FrontPortSerializer
     filterset_class = filters.FrontPortFilterSet
     filterset_class = filters.FrontPortFilterSet
+    brief_prefetch_fields = ['device']
 
 
 
 
 class RearPortViewSet(PassThroughPortMixin, ModelViewSet):
 class RearPortViewSet(PassThroughPortMixin, ModelViewSet):
     queryset = RearPort.objects.prefetch_related('device__device_type__manufacturer', 'cable', 'tags')
     queryset = RearPort.objects.prefetch_related('device__device_type__manufacturer', 'cable', 'tags')
     serializer_class = serializers.RearPortSerializer
     serializer_class = serializers.RearPortSerializer
     filterset_class = filters.RearPortFilterSet
     filterset_class = filters.RearPortFilterSet
+    brief_prefetch_fields = ['device']
 
 
 
 
 class DeviceBayViewSet(ModelViewSet):
 class DeviceBayViewSet(ModelViewSet):
     queryset = DeviceBay.objects.prefetch_related('installed_device').prefetch_related('tags')
     queryset = DeviceBay.objects.prefetch_related('installed_device').prefetch_related('tags')
     serializer_class = serializers.DeviceBaySerializer
     serializer_class = serializers.DeviceBaySerializer
     filterset_class = filters.DeviceBayFilterSet
     filterset_class = filters.DeviceBayFilterSet
+    brief_prefetch_fields = ['device']
 
 
 
 
 class InventoryItemViewSet(ModelViewSet):
 class InventoryItemViewSet(ModelViewSet):
     queryset = InventoryItem.objects.prefetch_related('device', 'manufacturer').prefetch_related('tags')
     queryset = InventoryItem.objects.prefetch_related('device', 'manufacturer').prefetch_related('tags')
     serializer_class = serializers.InventoryItemSerializer
     serializer_class = serializers.InventoryItemSerializer
     filterset_class = filters.InventoryItemFilterSet
     filterset_class = filters.InventoryItemFilterSet
+    brief_prefetch_fields = ['device']
 
 
 
 
 #
 #
@@ -600,6 +610,7 @@ class VirtualChassisViewSet(ModelViewSet):
     )
     )
     serializer_class = serializers.VirtualChassisSerializer
     serializer_class = serializers.VirtualChassisSerializer
     filterset_class = filters.VirtualChassisFilterSet
     filterset_class = filters.VirtualChassisFilterSet
+    brief_prefetch_fields = ['master']
 
 
 
 
 #
 #

+ 4 - 5
netbox/extras/api/views.py

@@ -39,7 +39,6 @@ class ConfigContextQuerySetMixin:
     Provides a get_queryset() method which deals with adding the config context
     Provides a get_queryset() method which deals with adding the config context
     data annotation or not.
     data annotation or not.
     """
     """
-
     def get_queryset(self):
     def get_queryset(self):
         """
         """
         Build the proper queryset based on the request context
         Build the proper queryset based on the request context
@@ -49,11 +48,11 @@ class ConfigContextQuerySetMixin:
 
 
         Else, return the queryset annotated with config context data
         Else, return the queryset annotated with config context data
         """
         """
-
+        queryset = super().get_queryset()
         request = self.get_serializer_context()['request']
         request = self.get_serializer_context()['request']
-        if request.query_params.get('brief') or 'config_context' in request.query_params.get('exclude', []):
-            return self.queryset
-        return self.queryset.annotate_config_context_data()
+        if self.brief or 'config_context' in request.query_params.get('exclude', []):
+            return queryset
+        return queryset.annotate_config_context_data()
 
 
 
 
 #
 #

+ 24 - 16
netbox/netbox/api/views.py

@@ -9,11 +9,11 @@ from django.core.exceptions import ObjectDoesNotExist, PermissionDenied
 from django.db import transaction
 from django.db import transaction
 from django.db.models import ProtectedError
 from django.db.models import ProtectedError
 from django_rq.queues import get_connection
 from django_rq.queues import get_connection
-from rest_framework import mixins, status
+from rest_framework import status
 from rest_framework.response import Response
 from rest_framework.response import Response
 from rest_framework.reverse import reverse
 from rest_framework.reverse import reverse
 from rest_framework.views import APIView
 from rest_framework.views import APIView
-from rest_framework.viewsets import GenericViewSet
+from rest_framework.viewsets import ModelViewSet as ModelViewSet_
 from rq.worker import Worker
 from rq.worker import Worker
 
 
 from netbox.api import BulkOperationSerializer
 from netbox.api import BulkOperationSerializer
@@ -120,17 +120,13 @@ class BulkDestroyModelMixin:
 # Viewsets
 # Viewsets
 #
 #
 
 
-class ModelViewSet(mixins.CreateModelMixin,
-                   mixins.RetrieveModelMixin,
-                   mixins.UpdateModelMixin,
-                   mixins.DestroyModelMixin,
-                   mixins.ListModelMixin,
-                   BulkUpdateModelMixin,
-                   BulkDestroyModelMixin,
-                   GenericViewSet):
+class ModelViewSet(BulkUpdateModelMixin, BulkDestroyModelMixin, ModelViewSet_):
     """
     """
-    Accept either a single object or a list of objects to create.
+    Extend DRF's ModelViewSet to support bulk update and delete functions.
     """
     """
+    brief = False
+    brief_prefetch_fields = []
+
     def get_serializer(self, *args, **kwargs):
     def get_serializer(self, *args, **kwargs):
 
 
         # If a list of objects has been provided, initialize the serializer with many=True
         # If a list of objects has been provided, initialize the serializer with many=True
@@ -142,22 +138,34 @@ class ModelViewSet(mixins.CreateModelMixin,
     def get_serializer_class(self):
     def get_serializer_class(self):
         logger = logging.getLogger('netbox.api.views.ModelViewSet')
         logger = logging.getLogger('netbox.api.views.ModelViewSet')
 
 
-        # If 'brief' has been passed as a query param, find and return the nested serializer for this model, if one
-        # exists
-        request = self.get_serializer_context()['request']
-        if request.query_params.get('brief'):
+        # If using 'brief' mode, find and return the nested serializer for this model, if one exists
+        if self.brief:
             logger.debug("Request is for 'brief' format; initializing nested serializer")
             logger.debug("Request is for 'brief' format; initializing nested serializer")
             try:
             try:
                 serializer = get_serializer_for_model(self.queryset.model, prefix='Nested')
                 serializer = get_serializer_for_model(self.queryset.model, prefix='Nested')
                 logger.debug(f"Using serializer {serializer}")
                 logger.debug(f"Using serializer {serializer}")
                 return serializer
                 return serializer
             except SerializerNotFound:
             except SerializerNotFound:
-                pass
+                logger.debug(f"Nested serializer for {self.queryset.model} not found!")
 
 
         # Fall back to the hard-coded serializer class
         # Fall back to the hard-coded serializer class
         logger.debug(f"Using serializer {self.serializer_class}")
         logger.debug(f"Using serializer {self.serializer_class}")
         return self.serializer_class
         return self.serializer_class
 
 
+    def get_queryset(self):
+        # If using brief mode, clear all prefetches from the queryset and append only brief_prefetch_fields (if any)
+        if self.brief:
+            return super().get_queryset().prefetch_related(None).prefetch_related(*self.brief_prefetch_fields)
+
+        return super().get_queryset()
+
+    def initialize_request(self, request, *args, **kwargs):
+        # Check if brief=True has been passed
+        if request.method == 'GET' and request.GET.get('brief'):
+            self.brief = True
+
+        return super().initialize_request(request, *args, **kwargs)
+
     def initial(self, request, *args, **kwargs):
     def initial(self, request, *args, **kwargs):
         super().initial(request, *args, **kwargs)
         super().initial(request, *args, **kwargs)
 
 

+ 1 - 0
netbox/virtualization/api/views.py

@@ -84,3 +84,4 @@ class VMInterfaceViewSet(ModelViewSet):
     )
     )
     serializer_class = serializers.VMInterfaceSerializer
     serializer_class = serializers.VMInterfaceSerializer
     filterset_class = filters.VMInterfaceFilterSet
     filterset_class = filters.VMInterfaceFilterSet
+    brief_prefetch_fields = ['virtual_machine']