Browse Source

Call restrict() when retrieving related Graphs

Jeremy Stretch 5 years ago
parent
commit
0dbe248df8

+ 2 - 2
netbox/circuits/api/views.py

@@ -28,8 +28,8 @@ class ProviderViewSet(CustomFieldModelViewSet):
         """
         """
         A convenience method for rendering graphs for a particular provider.
         A convenience method for rendering graphs for a particular provider.
         """
         """
-        provider = get_object_or_404(Provider, pk=pk)
-        queryset = Graph.objects.filter(type__model='provider')
+        provider = get_object_or_404(self.queryset, pk=pk)
+        queryset = Graph.objects.restrict(request.user).filter(type__model='provider')
         serializer = RenderedGraphSerializer(queryset, many=True, context={'graphed_object': provider})
         serializer = RenderedGraphSerializer(queryset, many=True, context={'graphed_object': provider})
         return Response(serializer.data)
         return Response(serializer.data)
 
 

+ 1 - 1
netbox/circuits/tests/test_api.py

@@ -49,7 +49,7 @@ class ProviderTest(APIViewTestCases.APIViewTestCase):
         """
         """
         Test retrieval of Graphs assigned to Providers.
         Test retrieval of Graphs assigned to Providers.
         """
         """
-        provider = self.model.objects.first()
+        provider = self.model.objects.unrestricted().first()
         ct = ContentType.objects.get(app_label='circuits', model='provider')
         ct = ContentType.objects.get(app_label='circuits', model='provider')
         graphs = (
         graphs = (
             Graph(type=ct, name='Graph 1', source='http://example.com/graphs.py?provider={{ obj.slug }}&foo=1'),
             Graph(type=ct, name='Graph 1', source='http://example.com/graphs.py?provider={{ obj.slug }}&foo=1'),

+ 6 - 6
netbox/dcim/api/views.py

@@ -103,8 +103,8 @@ class SiteViewSet(CustomFieldModelViewSet):
         """
         """
         A convenience method for rendering graphs for a particular site.
         A convenience method for rendering graphs for a particular site.
         """
         """
-        site = get_object_or_404(Site, pk=pk)
-        queryset = Graph.objects.filter(type__model='site')
+        site = get_object_or_404(self.queryset, pk=pk)
+        queryset = Graph.objects.restrict(request.user).filter(type__model='site')
         serializer = RenderedGraphSerializer(queryset, many=True, context={'graphed_object': site})
         serializer = RenderedGraphSerializer(queryset, many=True, context={'graphed_object': site})
         return Response(serializer.data)
         return Response(serializer.data)
 
 
@@ -347,8 +347,8 @@ class DeviceViewSet(CustomFieldModelViewSet):
         """
         """
         A convenience method for rendering graphs for a particular Device.
         A convenience method for rendering graphs for a particular Device.
         """
         """
-        device = get_object_or_404(Device, pk=pk)
-        queryset = Graph.objects.filter(type__model='device')
+        device = get_object_or_404(self.queryset, pk=pk)
+        queryset = Graph.objects.restrict(request.user).filter(type__model='device')
         serializer = RenderedGraphSerializer(queryset, many=True, context={'graphed_object': device})
         serializer = RenderedGraphSerializer(queryset, many=True, context={'graphed_object': device})
 
 
         return Response(serializer.data)
         return Response(serializer.data)
@@ -496,8 +496,8 @@ class InterfaceViewSet(CableTraceMixin, ModelViewSet):
         """
         """
         A convenience method for rendering graphs for a particular interface.
         A convenience method for rendering graphs for a particular interface.
         """
         """
-        interface = get_object_or_404(Interface, pk=pk)
-        queryset = Graph.objects.filter(type__model='interface')
+        interface = get_object_or_404(self.queryset, pk=pk)
+        queryset = Graph.objects.restrict(request.user).filter(type__model='interface')
         serializer = RenderedGraphSerializer(queryset, many=True, context={'graphed_object': interface})
         serializer = RenderedGraphSerializer(queryset, many=True, context={'graphed_object': interface})
         return Response(serializer.data)
         return Response(serializer.data)
 
 

+ 3 - 3
netbox/dcim/tests/test_api.py

@@ -107,7 +107,7 @@ class SiteTest(APIViewTestCases.APIViewTestCase):
         Graph.objects.bulk_create(graphs)
         Graph.objects.bulk_create(graphs)
 
 
         self.add_permissions('dcim.view_site')
         self.add_permissions('dcim.view_site')
-        url = reverse('dcim-api:site-graphs', kwargs={'pk': Site.objects.first().pk})
+        url = reverse('dcim-api:site-graphs', kwargs={'pk': Site.objects.unrestricted().first().pk})
         response = self.client.get(url, **self.header)
         response = self.client.get(url, **self.header)
 
 
         self.assertEqual(len(response.data), 3)
         self.assertEqual(len(response.data), 3)
@@ -878,7 +878,7 @@ class DeviceTest(APIViewTestCases.APIViewTestCase):
         Graph.objects.bulk_create(graphs)
         Graph.objects.bulk_create(graphs)
 
 
         self.add_permissions('dcim.view_device')
         self.add_permissions('dcim.view_device')
-        url = reverse('dcim-api:device-graphs', kwargs={'pk': Device.objects.first().pk})
+        url = reverse('dcim-api:device-graphs', kwargs={'pk': Device.objects.unrestricted().first().pk})
         response = self.client.get(url, **self.header)
         response = self.client.get(url, **self.header)
 
 
         self.assertEqual(len(response.data), 3)
         self.assertEqual(len(response.data), 3)
@@ -1245,7 +1245,7 @@ class InterfaceTest(APIViewTestCases.APIViewTestCase):
         Graph.objects.bulk_create(graphs)
         Graph.objects.bulk_create(graphs)
 
 
         self.add_permissions('dcim.view_interface')
         self.add_permissions('dcim.view_interface')
-        url = reverse('dcim-api:interface-graphs', kwargs={'pk': Interface.objects.first().pk})
+        url = reverse('dcim-api:interface-graphs', kwargs={'pk': Interface.objects.unrestricted().first().pk})
         response = self.client.get(url, **self.header)
         response = self.client.get(url, **self.header)
 
 
         self.assertEqual(len(response.data), 3)
         self.assertEqual(len(response.data), 3)