Răsfoiți Sursa

Fix some instances where RestrictedQuerySet is evaluated prematurely

Jeremy Stretch 5 ani în urmă
părinte
comite
95965d65c9

+ 13 - 6
netbox/dcim/views.py

@@ -5,7 +5,7 @@ from django.contrib import messages
 from django.contrib.contenttypes.models import ContentType
 from django.contrib.contenttypes.models import ContentType
 from django.core.paginator import EmptyPage, PageNotAnInteger
 from django.core.paginator import EmptyPage, PageNotAnInteger
 from django.db import transaction
 from django.db import transaction
-from django.db.models import Count, F
+from django.db.models import Count, F, Prefetch
 from django.forms import ModelMultipleChoiceField, MultipleHiddenInput, modelformset_factory
 from django.forms import ModelMultipleChoiceField, MultipleHiddenInput, modelformset_factory
 from django.shortcuts import get_object_or_404, redirect, render
 from django.shortcuts import get_object_or_404, redirect, render
 from django.urls import reverse
 from django.urls import reverse
@@ -16,7 +16,7 @@ from django.views.generic import View
 from circuits.models import Circuit
 from circuits.models import Circuit
 from extras.models import Graph
 from extras.models import Graph
 from extras.views import ObjectConfigContextView
 from extras.views import ObjectConfigContextView
-from ipam.models import Prefix, Service, VLAN
+from ipam.models import IPAddress, Prefix, Service, VLAN
 from ipam.tables import InterfaceIPAddressTable, InterfaceVLANTable
 from ipam.tables import InterfaceIPAddressTable, InterfaceVLANTable
 from secrets.models import Secret
 from secrets.models import Secret
 from utilities.forms import ConfirmationForm
 from utilities.forms import ConfirmationForm
@@ -517,6 +517,7 @@ class DeviceTypeView(ObjectView):
     def get(self, request, pk):
     def get(self, request, pk):
 
 
         devicetype = get_object_or_404(self.queryset, pk=pk)
         devicetype = get_object_or_404(self.queryset, pk=pk)
+        instance_count = Device.objects.restrict(request.user).filter(device_type=devicetype).count()
 
 
         # Component tables
         # Component tables
         consoleport_table = tables.ConsolePortTemplateTable(
         consoleport_table = tables.ConsolePortTemplateTable(
@@ -563,6 +564,7 @@ class DeviceTypeView(ObjectView):
 
 
         return render(request, 'dcim/devicetype.html', {
         return render(request, 'dcim/devicetype.html', {
             'devicetype': devicetype,
             'devicetype': devicetype,
+            'instance_count': instance_count,
             'consoleport_table': consoleport_table,
             'consoleport_table': consoleport_table,
             'consoleserverport_table': consoleserverport_table,
             'consoleserverport_table': consoleserverport_table,
             'powerport_table': powerport_table,
             'powerport_table': powerport_table,
@@ -987,8 +989,10 @@ class DeviceView(ObjectView):
 
 
         # Interfaces
         # Interfaces
         interfaces = device.vc_interfaces.restrict(request.user, 'view').filter(device=device).prefetch_related(
         interfaces = device.vc_interfaces.restrict(request.user, 'view').filter(device=device).prefetch_related(
+            Prefetch('ip_addresses', queryset=IPAddress.objects.restrict(request.user)),
+            Prefetch('member_interfaces', queryset=Interface.objects.restrict(request.user)),
             'lag', '_connected_interface__device', '_connected_circuittermination__circuit', 'cable',
             'lag', '_connected_interface__device', '_connected_circuittermination__circuit', 'cable',
-            'cable__termination_a', 'cable__termination_b', 'ip_addresses', 'tags'
+            'cable__termination_a', 'cable__termination_b', 'tags'
         )
         )
 
 
         # Front ports
         # Front ports
@@ -1438,7 +1442,7 @@ class InterfaceView(ObjectView):
         if interface.untagged_vlan is not None:
         if interface.untagged_vlan is not None:
             vlans.append(interface.untagged_vlan)
             vlans.append(interface.untagged_vlan)
             vlans[0].tagged = False
             vlans[0].tagged = False
-        for vlan in interface.tagged_vlans.prefetch_related('site', 'group', 'tenant', 'role'):
+        for vlan in interface.tagged_vlans.restrict(request.user).prefetch_related('site', 'group', 'tenant', 'role'):
             vlan.tagged = True
             vlan.tagged = True
             vlans.append(vlan)
             vlans.append(vlan)
         vlan_table = InterfaceVLANTable(
         vlan_table = InterfaceVLANTable(
@@ -2149,13 +2153,15 @@ class VirtualChassisListView(ObjectListView):
 
 
 
 
 class VirtualChassisView(ObjectView):
 class VirtualChassisView(ObjectView):
-    queryset = VirtualChassis.objects.prefetch_related('members')
+    queryset = VirtualChassis.objects.all()
 
 
     def get(self, request, pk):
     def get(self, request, pk):
         virtualchassis = get_object_or_404(self.queryset, pk=pk)
         virtualchassis = get_object_or_404(self.queryset, pk=pk)
+        members = Device.objects.restrict(request.user).filter(virtual_chassis=virtualchassis)
 
 
         return render(request, 'dcim/virtualchassis.html', {
         return render(request, 'dcim/virtualchassis.html', {
             'virtualchassis': virtualchassis,
             'virtualchassis': virtualchassis,
+            'members': members,
         })
         })
 
 
 
 
@@ -2389,8 +2395,9 @@ class PowerPanelView(ObjectView):
     def get(self, request, pk):
     def get(self, request, pk):
 
 
         powerpanel = get_object_or_404(self.queryset, pk=pk)
         powerpanel = get_object_or_404(self.queryset, pk=pk)
+        power_feeds = PowerFeed.objects.restrict(request.user).filter(power_panel=powerpanel).prefetch_related('rack')
         powerfeed_table = tables.PowerFeedTable(
         powerfeed_table = tables.PowerFeedTable(
-            data=PowerFeed.objects.filter(power_panel=powerpanel).prefetch_related('rack'),
+            data=power_feeds,
             orderable=False
             orderable=False
         )
         )
         powerfeed_table.exclude = ['power_panel']
         powerfeed_table.exclude = ['power_panel']

+ 16 - 1
netbox/extras/views.py

@@ -2,13 +2,15 @@ from django import template
 from django.conf import settings
 from django.conf import settings
 from django.contrib import messages
 from django.contrib import messages
 from django.contrib.contenttypes.models import ContentType
 from django.contrib.contenttypes.models import ContentType
-from django.db.models import Count, Q
+from django.db.models import Count, Prefetch, Q
 from django.http import Http404, HttpResponseForbidden
 from django.http import Http404, HttpResponseForbidden
 from django.shortcuts import get_object_or_404, redirect, render
 from django.shortcuts import get_object_or_404, redirect, render
 from django.utils.safestring import mark_safe
 from django.utils.safestring import mark_safe
 from django.views.generic import View
 from django.views.generic import View
 from django_tables2 import RequestConfig
 from django_tables2 import RequestConfig
 
 
+from dcim.models import DeviceRole, Platform, Region, Site
+from tenancy.models import Tenant, TenantGroup
 from utilities.forms import ConfirmationForm
 from utilities.forms import ConfirmationForm
 from utilities.paginator import EnhancedPaginator
 from utilities.paginator import EnhancedPaginator
 from utilities.utils import shallow_compare_dict
 from utilities.utils import shallow_compare_dict
@@ -16,6 +18,7 @@ from utilities.views import (
     BulkDeleteView, BulkEditView, BulkImportView, ObjectView, ObjectDeleteView, ObjectEditView, ObjectListView,
     BulkDeleteView, BulkEditView, BulkImportView, ObjectView, ObjectDeleteView, ObjectEditView, ObjectListView,
     ObjectPermissionRequiredMixin,
     ObjectPermissionRequiredMixin,
 )
 )
+from virtualization.models import Cluster, ClusterGroup
 from . import filters, forms, tables
 from . import filters, forms, tables
 from .models import ConfigContext, ImageAttachment, ObjectChange, ReportResult, Tag, TaggedItem
 from .models import ConfigContext, ImageAttachment, ObjectChange, ReportResult, Tag, TaggedItem
 from .reports import get_report, get_reports
 from .reports import get_report, get_reports
@@ -120,6 +123,18 @@ class ConfigContextView(ObjectView):
     queryset = ConfigContext.objects.all()
     queryset = ConfigContext.objects.all()
 
 
     def get(self, request, pk):
     def get(self, request, pk):
+        # Extend queryset to prefetch related objects
+        self.queryset = self.queryset.prefetch_related(
+            Prefetch('regions', queryset=Region.objects.restrict(request.user)),
+            Prefetch('sites', queryset=Site.objects.restrict(request.user)),
+            Prefetch('roles', queryset=DeviceRole.objects.restrict(request.user)),
+            Prefetch('platforms', queryset=Platform.objects.restrict(request.user)),
+            Prefetch('clusters', queryset=Cluster.objects.restrict(request.user)),
+            Prefetch('cluster_groups', queryset=ClusterGroup.objects.restrict(request.user)),
+            Prefetch('tenants', queryset=Tenant.objects.restrict(request.user)),
+            Prefetch('tenant_groups', queryset=TenantGroup.objects.restrict(request.user)),
+        )
+
         configcontext = get_object_or_404(self.queryset, pk=pk)
         configcontext = get_object_or_404(self.queryset, pk=pk)
 
 
         # Determine user's preferred output format
         # Determine user's preferred output format

+ 5 - 2
netbox/ipam/models.py

@@ -255,7 +255,7 @@ class Aggregate(ChangeLoggedModel, CustomFieldModel):
         """
         """
         Determine the prefix utilization of the aggregate and return it as a percentage.
         Determine the prefix utilization of the aggregate and return it as a percentage.
         """
         """
-        queryset = Prefix.objects.filter(prefix__net_contained_or_equal=str(self.prefix))
+        queryset = Prefix.objects.unrestricted().filter(prefix__net_contained_or_equal=str(self.prefix))
         child_prefixes = netaddr.IPSet([p.prefix for p in queryset])
         child_prefixes = netaddr.IPSet([p.prefix for p in queryset])
         return int(float(child_prefixes.size) / self.prefix.size * 100)
         return int(float(child_prefixes.size) / self.prefix.size * 100)
 
 
@@ -553,7 +553,10 @@ class Prefix(ChangeLoggedModel, CustomFieldModel):
         "container", calculate utilization based on child prefixes. For all others, count child IP addresses.
         "container", calculate utilization based on child prefixes. For all others, count child IP addresses.
         """
         """
         if self.status == PrefixStatusChoices.STATUS_CONTAINER:
         if self.status == PrefixStatusChoices.STATUS_CONTAINER:
-            queryset = Prefix.objects.filter(prefix__net_contained=str(self.prefix), vrf=self.vrf)
+            queryset = Prefix.objects.unrestricted().filter(
+                prefix__net_contained=str(self.prefix),
+                vrf=self.vrf
+            )
             child_prefixes = netaddr.IPSet([p.prefix for p in queryset])
             child_prefixes = netaddr.IPSet([p.prefix for p in queryset])
             return int(float(child_prefixes.size) / self.prefix.size * 100)
             return int(float(child_prefixes.size) / self.prefix.size * 100)
         else:
         else:

+ 8 - 4
netbox/ipam/views.py

@@ -1,6 +1,6 @@
 import netaddr
 import netaddr
 from django.conf import settings
 from django.conf import settings
-from django.db.models import Count
+from django.db.models import Count, Prefetch
 from django.db.models.expressions import RawSQL
 from django.db.models.expressions import RawSQL
 from django.shortcuts import get_object_or_404, redirect, render
 from django.shortcuts import get_object_or_404, redirect, render
 from django_tables2 import RequestConfig
 from django_tables2 import RequestConfig
@@ -108,10 +108,12 @@ class RIRListView(ObjectListView):
                 'deprecated': 0,
                 'deprecated': 0,
                 'available': 0,
                 'available': 0,
             }
             }
-            aggregate_list = Aggregate.objects.filter(prefix__family=family, rir=rir)
+            aggregate_list = Aggregate.objects.restrict(request.user).filter(prefix__family=family, rir=rir)
             for aggregate in aggregate_list:
             for aggregate in aggregate_list:
 
 
-                queryset = Prefix.objects.filter(prefix__net_contained_or_equal=str(aggregate.prefix))
+                queryset = Prefix.objects.restrict(request.user).filter(
+                    prefix__net_contained_or_equal=str(aggregate.prefix)
+                )
 
 
                 # Find all consumed space for each prefix status (we ignore containers for this purpose).
                 # Find all consumed space for each prefix status (we ignore containers for this purpose).
                 active_prefixes = netaddr.cidr_merge(
                 active_prefixes = netaddr.cidr_merge(
@@ -699,7 +701,9 @@ class VLANGroupVLANsView(ObjectView):
     def get(self, request, pk):
     def get(self, request, pk):
         vlan_group = get_object_or_404(self.queryset, pk=pk)
         vlan_group = get_object_or_404(self.queryset, pk=pk)
 
 
-        vlans = VLAN.objects.restrict(request.user, 'view').filter(group_id=pk)
+        vlans = VLAN.objects.restrict(request.user, 'view').filter(group_id=pk).prefetch_related(
+            Prefetch('prefixes', queryset=Prefix.objects.restrict(request.user))
+        )
         vlans = add_available_vlans(vlan_group, vlans)
         vlans = add_available_vlans(vlan_group, vlans)
 
 
         vlan_table = tables.VLANDetailTable(vlans)
         vlan_table = tables.VLANDetailTable(vlans)

+ 1 - 1
netbox/templates/dcim/devicetype.html

@@ -131,7 +131,7 @@
                 </tr>
                 </tr>
                 <tr>
                 <tr>
                     <td>Instances</td>
                     <td>Instances</td>
-                    <td><a href="{% url 'dcim:device_list' %}?device_type_id={{ devicetype.pk }}">{{ devicetype.instances.count }}</a></td>
+                    <td><a href="{% url 'dcim:device_list' %}?device_type_id={{ devicetype.pk }}">{{ instance_count }}</a></td>
                 </tr>
                 </tr>
             </table>
             </table>
         </div>
         </div>

+ 1 - 1
netbox/templates/dcim/virtualchassis.html

@@ -93,7 +93,7 @@
                     <th>Master</th>
                     <th>Master</th>
                     <th>Priority</th>
                     <th>Priority</th>
                 </tr>
                 </tr>
-                {% for vc_member in virtualchassis.members.all %}
+                {% for vc_member in members %}
                     <tr{% if vc_member == device %} class="info"{% endif %}>
                     <tr{% if vc_member == device %} class="info"{% endif %}>
                         <td>
                         <td>
                             <a href="{{ vc_member.get_absolute_url }}">{{ vc_member }}</a>
                             <a href="{{ vc_member.get_absolute_url }}">{{ vc_member }}</a>

+ 16 - 5
netbox/virtualization/views.py

@@ -1,13 +1,13 @@
 from django.contrib import messages
 from django.contrib import messages
 from django.db import transaction
 from django.db import transaction
-from django.db.models import Count
+from django.db.models import Count, Prefetch
 from django.shortcuts import get_object_or_404, redirect, render
 from django.shortcuts import get_object_or_404, redirect, render
 from django.urls import reverse
 from django.urls import reverse
 
 
 from dcim.models import Device
 from dcim.models import Device
 from dcim.tables import DeviceTable
 from dcim.tables import DeviceTable
 from extras.views import ObjectConfigContextView
 from extras.views import ObjectConfigContextView
-from ipam.models import Service
+from ipam.models import IPAddress, Service
 from ipam.tables import InterfaceIPAddressTable, InterfaceVLANTable
 from ipam.tables import InterfaceIPAddressTable, InterfaceVLANTable
 from utilities.views import (
 from utilities.views import (
     BulkComponentCreateView, BulkDeleteView, BulkEditView, BulkImportView, BulkRenameView, ComponentCreateView,
     BulkComponentCreateView, BulkDeleteView, BulkEditView, BulkImportView, BulkRenameView, ComponentCreateView,
@@ -88,6 +88,9 @@ class ClusterView(ObjectView):
     queryset = Cluster.objects.all()
     queryset = Cluster.objects.all()
 
 
     def get(self, request, pk):
     def get(self, request, pk):
+        self.queryset = self.queryset.prefetch_related(
+            Prefetch('virtual_machines', queryset=VirtualMachine.objects.restrict(request.user))
+        )
 
 
         cluster = get_object_or_404(self.queryset, pk=pk)
         cluster = get_object_or_404(self.queryset, pk=pk)
         devices = Device.objects.restrict(request.user, 'view').filter(cluster=cluster).prefetch_related(
         devices = Device.objects.restrict(request.user, 'view').filter(cluster=cluster).prefetch_related(
@@ -236,8 +239,16 @@ class VirtualMachineView(ObjectView):
     def get(self, request, pk):
     def get(self, request, pk):
 
 
         virtualmachine = get_object_or_404(self.queryset, pk=pk)
         virtualmachine = get_object_or_404(self.queryset, pk=pk)
-        interfaces = VMInterface.objects.restrict(request.user, 'view').filter(virtual_machine=virtualmachine)
-        services = Service.objects.restrict(request.user, 'view').filter(virtual_machine=virtualmachine)
+        interfaces = VMInterface.objects.restrict(request.user, 'view').filter(
+            virtual_machine=virtualmachine
+        ).prefetch_related(
+            Prefetch('ip_addresses', queryset=IPAddress.objects.restrict(request.user))
+        )
+        services = Service.objects.restrict(request.user, 'view').filter(
+            virtual_machine=virtualmachine
+        ).prefetch_related(
+            Prefetch('ipaddresses', queryset=IPAddress.objects.restrict(request.user))
+        )
 
 
         return render(request, 'virtualization/virtualmachine.html', {
         return render(request, 'virtualization/virtualmachine.html', {
             'virtualmachine': virtualmachine,
             'virtualmachine': virtualmachine,
@@ -315,7 +326,7 @@ class VMInterfaceView(ObjectView):
         if vminterface.untagged_vlan is not None:
         if vminterface.untagged_vlan is not None:
             vlans.append(vminterface.untagged_vlan)
             vlans.append(vminterface.untagged_vlan)
             vlans[0].tagged = False
             vlans[0].tagged = False
-        for vlan in vminterface.tagged_vlans.prefetch_related('site', 'group', 'tenant', 'role'):
+        for vlan in vminterface.tagged_vlans.restrict(request.user).prefetch_related('site', 'group', 'tenant', 'role'):
             vlan.tagged = True
             vlan.tagged = True
             vlans.append(vlan)
             vlans.append(vlan)
         vlan_table = InterfaceVLANTable(
         vlan_table = InterfaceVLANTable(