Просмотр исходного кода

Merge pull request #4733 from netbox-community/4730-api-test-permissions

Closes #4730: Update REST API tests to enforce ObjectPermissions
Jeremy Stretch 5 лет назад
Родитель
Сommit
54dd20cdb4

+ 1 - 0
netbox/circuits/tests/test_api.py

@@ -58,6 +58,7 @@ class ProviderTest(APIViewTestCases.APIViewTestCase):
         )
         )
         Graph.objects.bulk_create(graphs)
         Graph.objects.bulk_create(graphs)
 
 
+        self.add_permissions('circuits.view_provider')
         url = reverse('circuits-api:provider-graphs', kwargs={'pk': provider.pk})
         url = reverse('circuits-api:provider-graphs', kwargs={'pk': provider.pk})
         response = self.client.get(url, **self.header)
         response = self.client.get(url, **self.header)
 
 

+ 33 - 7
netbox/dcim/tests/test_api.py

@@ -106,6 +106,7 @@ class SiteTest(APIViewTestCases.APIViewTestCase):
         )
         )
         Graph.objects.bulk_create(graphs)
         Graph.objects.bulk_create(graphs)
 
 
+        self.add_permissions('dcim.view_site')
         url = reverse('dcim-api:site-graphs', kwargs={'pk': Site.objects.first().pk})
         url = reverse('dcim-api:site-graphs', kwargs={'pk': Site.objects.first().pk})
         response = self.client.get(url, **self.header)
         response = self.client.get(url, **self.header)
 
 
@@ -245,6 +246,7 @@ class RackTest(APIViewTestCases.APIViewTestCase):
     def test_get_elevation_rack_units(self):
     def test_get_elevation_rack_units(self):
         rack = Rack.objects.first()
         rack = Rack.objects.first()
 
 
+        self.add_permissions('dcim.view_rack')
         url = '{}?q=3'.format(reverse('dcim-api:rack-elevation', kwargs={'pk': rack.pk}))
         url = '{}?q=3'.format(reverse('dcim-api:rack-elevation', kwargs={'pk': rack.pk}))
         response = self.client.get(url, **self.header)
         response = self.client.get(url, **self.header)
 
 
@@ -270,6 +272,7 @@ class RackTest(APIViewTestCases.APIViewTestCase):
         GET a single rack elevation.
         GET a single rack elevation.
         """
         """
         rack = Rack.objects.first()
         rack = Rack.objects.first()
+        self.add_permissions('dcim.view_rack')
         url = reverse('dcim-api:rack-elevation', kwargs={'pk': rack.pk})
         url = reverse('dcim-api:rack-elevation', kwargs={'pk': rack.pk})
         response = self.client.get(url, **self.header)
         response = self.client.get(url, **self.header)
 
 
@@ -280,6 +283,7 @@ class RackTest(APIViewTestCases.APIViewTestCase):
         GET a single rack elevation in SVG format.
         GET a single rack elevation in SVG format.
         """
         """
         rack = Rack.objects.first()
         rack = Rack.objects.first()
+        self.add_permissions('dcim.view_rack')
         url = '{}?render=svg'.format(reverse('dcim-api:rack-elevation', kwargs={'pk': rack.pk}))
         url = '{}?render=svg'.format(reverse('dcim-api:rack-elevation', kwargs={'pk': rack.pk}))
         response = self.client.get(url, **self.header)
         response = self.client.get(url, **self.header)
 
 
@@ -784,6 +788,7 @@ class DeviceTest(APIViewTestCases.APIViewTestCase):
         )
         )
         Graph.objects.bulk_create(graphs)
         Graph.objects.bulk_create(graphs)
 
 
+        self.add_permissions('dcim.view_device')
         url = reverse('dcim-api:device-graphs', kwargs={'pk': Device.objects.first().pk})
         url = reverse('dcim-api:device-graphs', kwargs={'pk': Device.objects.first().pk})
         response = self.client.get(url, **self.header)
         response = self.client.get(url, **self.header)
 
 
@@ -794,6 +799,7 @@ class DeviceTest(APIViewTestCases.APIViewTestCase):
         """
         """
         Check that config context data is included by default in the devices list.
         Check that config context data is included by default in the devices list.
         """
         """
+        self.add_permissions('dcim.view_device')
         url = reverse('dcim-api:device-list') + '?slug=device-with-context-data'
         url = reverse('dcim-api:device-list') + '?slug=device-with-context-data'
         response = self.client.get(url, **self.header)
         response = self.client.get(url, **self.header)
 
 
@@ -803,6 +809,7 @@ class DeviceTest(APIViewTestCases.APIViewTestCase):
         """
         """
         Check that config context data can be excluded by passing ?exclude=config_context.
         Check that config context data can be excluded by passing ?exclude=config_context.
         """
         """
+        self.add_permissions('dcim.view_device')
         url = reverse('dcim-api:device-list') + '?exclude=config_context'
         url = reverse('dcim-api:device-list') + '?exclude=config_context'
         response = self.client.get(url, **self.header)
         response = self.client.get(url, **self.header)
 
 
@@ -820,6 +827,7 @@ class DeviceTest(APIViewTestCases.APIViewTestCase):
             'name': device.name,
             'name': device.name,
         }
         }
 
 
+        self.add_permissions('dcim.add_device')
         url = reverse('dcim-api:device-list')
         url = reverse('dcim-api:device-list')
         response = self.client.post(url, data, format='json', **self.header)
         response = self.client.post(url, data, format='json', **self.header)
 
 
@@ -878,6 +886,7 @@ class ConsolePortTest(APIViewTestCases.APIViewTestCase):
         cable = Cable(termination_a=consoleport, termination_b=consoleserverport, label='Cable 1')
         cable = Cable(termination_a=consoleport, termination_b=consoleserverport, label='Cable 1')
         cable.save()
         cable.save()
 
 
+        self.add_permissions('dcim.view_consoleport')
         url = reverse('dcim-api:consoleport-trace', kwargs={'pk': consoleport.pk})
         url = reverse('dcim-api:consoleport-trace', kwargs={'pk': consoleport.pk})
         response = self.client.get(url, **self.header)
         response = self.client.get(url, **self.header)
 
 
@@ -941,6 +950,7 @@ class ConsoleServerPortTest(APIViewTestCases.APIViewTestCase):
         cable = Cable(termination_a=consoleserverport, termination_b=consoleport, label='Cable 1')
         cable = Cable(termination_a=consoleserverport, termination_b=consoleport, label='Cable 1')
         cable.save()
         cable.save()
 
 
+        self.add_permissions('dcim.view_consoleserverport')
         url = reverse('dcim-api:consoleserverport-trace', kwargs={'pk': consoleserverport.pk})
         url = reverse('dcim-api:consoleserverport-trace', kwargs={'pk': consoleserverport.pk})
         response = self.client.get(url, **self.header)
         response = self.client.get(url, **self.header)
 
 
@@ -1004,6 +1014,7 @@ class PowerPortTest(APIViewTestCases.APIViewTestCase):
         cable = Cable(termination_a=powerport, termination_b=poweroutlet, label='Cable 1')
         cable = Cable(termination_a=powerport, termination_b=poweroutlet, label='Cable 1')
         cable.save()
         cable.save()
 
 
+        self.add_permissions('dcim.view_powerport')
         url = reverse('dcim-api:powerport-trace', kwargs={'pk': powerport.pk})
         url = reverse('dcim-api:powerport-trace', kwargs={'pk': powerport.pk})
         response = self.client.get(url, **self.header)
         response = self.client.get(url, **self.header)
 
 
@@ -1067,6 +1078,7 @@ class PowerOutletTest(APIViewTestCases.APIViewTestCase):
         cable = Cable(termination_a=poweroutlet, termination_b=powerport, label='Cable 1')
         cable = Cable(termination_a=poweroutlet, termination_b=powerport, label='Cable 1')
         cable.save()
         cable.save()
 
 
+        self.add_permissions('dcim.view_poweroutlet')
         url = reverse('dcim-api:poweroutlet-trace', kwargs={'pk': poweroutlet.pk})
         url = reverse('dcim-api:poweroutlet-trace', kwargs={'pk': poweroutlet.pk})
         response = self.client.get(url, **self.header)
         response = self.client.get(url, **self.header)
 
 
@@ -1143,6 +1155,7 @@ class InterfaceTest(APIViewTestCases.APIViewTestCase):
         )
         )
         Graph.objects.bulk_create(graphs)
         Graph.objects.bulk_create(graphs)
 
 
+        self.add_permissions('dcim.view_interface')
         url = reverse('dcim-api:interface-graphs', kwargs={'pk': Interface.objects.first().pk})
         url = reverse('dcim-api:interface-graphs', kwargs={'pk': Interface.objects.first().pk})
         response = self.client.get(url, **self.header)
         response = self.client.get(url, **self.header)
 
 
@@ -1446,6 +1459,7 @@ class ConnectionTest(APITestCase):
             'termination_b_id': consoleserverport1.pk,
             'termination_b_id': consoleserverport1.pk,
         }
         }
 
 
+        self.add_permissions('dcim.add_cable')
         url = reverse('dcim-api:cable-list')
         url = reverse('dcim-api:cable-list')
         response = self.client.post(url, data, format='json', **self.header)
         response = self.client.post(url, data, format='json', **self.header)
 
 
@@ -1484,6 +1498,7 @@ class ConnectionTest(APITestCase):
             device=self.panel2, name='Test Front Port 2', type=PortTypeChoices.TYPE_8P8C, rear_port=rearport2
             device=self.panel2, name='Test Front Port 2', type=PortTypeChoices.TYPE_8P8C, rear_port=rearport2
         )
         )
 
 
+        self.add_permissions('dcim.add_cable')
         url = reverse('dcim-api:cable-list')
         url = reverse('dcim-api:cable-list')
         cables = [
         cables = [
             # Console port to panel1 front
             # Console port to panel1 front
@@ -1539,6 +1554,7 @@ class ConnectionTest(APITestCase):
             'termination_b_id': poweroutlet1.pk,
             'termination_b_id': poweroutlet1.pk,
         }
         }
 
 
+        self.add_permissions('dcim.add_cable')
         url = reverse('dcim-api:cable-list')
         url = reverse('dcim-api:cable-list')
         response = self.client.post(url, data, format='json', **self.header)
         response = self.client.post(url, data, format='json', **self.header)
 
 
@@ -1574,6 +1590,7 @@ class ConnectionTest(APITestCase):
             'termination_b_id': interface2.pk,
             'termination_b_id': interface2.pk,
         }
         }
 
 
+        self.add_permissions('dcim.add_cable')
         url = reverse('dcim-api:cable-list')
         url = reverse('dcim-api:cable-list')
         response = self.client.post(url, data, format='json', **self.header)
         response = self.client.post(url, data, format='json', **self.header)
 
 
@@ -1612,6 +1629,7 @@ class ConnectionTest(APITestCase):
             device=self.panel2, name='Test Front Port 2', type=PortTypeChoices.TYPE_8P8C, rear_port=rearport2
             device=self.panel2, name='Test Front Port 2', type=PortTypeChoices.TYPE_8P8C, rear_port=rearport2
         )
         )
 
 
+        self.add_permissions('dcim.add_cable')
         url = reverse('dcim-api:cable-list')
         url = reverse('dcim-api:cable-list')
         cables = [
         cables = [
             # Interface1 to panel1 front
             # Interface1 to panel1 front
@@ -1676,6 +1694,7 @@ class ConnectionTest(APITestCase):
             'termination_b_id': circuittermination1.pk,
             'termination_b_id': circuittermination1.pk,
         }
         }
 
 
+        self.add_permissions('dcim.add_cable')
         url = reverse('dcim-api:cable-list')
         url = reverse('dcim-api:cable-list')
         response = self.client.post(url, data, format='json', **self.header)
         response = self.client.post(url, data, format='json', **self.header)
 
 
@@ -1723,6 +1742,7 @@ class ConnectionTest(APITestCase):
             device=self.panel2, name='Test Front Port 2', type=PortTypeChoices.TYPE_8P8C, rear_port=rearport2
             device=self.panel2, name='Test Front Port 2', type=PortTypeChoices.TYPE_8P8C, rear_port=rearport2
         )
         )
 
 
+        self.add_permissions('dcim.add_cable')
         url = reverse('dcim-api:cable-list')
         url = reverse('dcim-api:cable-list')
         cables = [
         cables = [
             # Interface to panel1 front
             # Interface to panel1 front
@@ -1826,6 +1846,9 @@ class VirtualChassisTest(APIViewTestCases.APIViewTestCase):
             Device(name='Device 7', device_type=devicetype, device_role=devicerole, site=site),
             Device(name='Device 7', device_type=devicetype, device_role=devicerole, site=site),
             Device(name='Device 8', device_type=devicetype, device_role=devicerole, site=site),
             Device(name='Device 8', device_type=devicetype, device_role=devicerole, site=site),
             Device(name='Device 9', device_type=devicetype, device_role=devicerole, site=site),
             Device(name='Device 9', device_type=devicetype, device_role=devicerole, site=site),
+            Device(name='Device 10', device_type=devicetype, device_role=devicerole, site=site),
+            Device(name='Device 11', device_type=devicetype, device_role=devicerole, site=site),
+            Device(name='Device 12', device_type=devicetype, device_role=devicerole, site=site),
         )
         )
         Device.objects.bulk_create(devices)
         Device.objects.bulk_create(devices)
 
 
@@ -1839,16 +1862,19 @@ class VirtualChassisTest(APIViewTestCases.APIViewTestCase):
                 )
                 )
         Interface.objects.bulk_create(interfaces)
         Interface.objects.bulk_create(interfaces)
 
 
-        # Create two VirtualChassis with three members each
+        # Create three VirtualChassis with three members each
         virtual_chassis = (
         virtual_chassis = (
             VirtualChassis(master=devices[0], domain='domain-1'),
             VirtualChassis(master=devices[0], domain='domain-1'),
             VirtualChassis(master=devices[3], domain='domain-2'),
             VirtualChassis(master=devices[3], domain='domain-2'),
+            VirtualChassis(master=devices[6], domain='domain-3'),
         )
         )
         VirtualChassis.objects.bulk_create(virtual_chassis)
         VirtualChassis.objects.bulk_create(virtual_chassis)
         Device.objects.filter(pk=devices[1].pk).update(virtual_chassis=virtual_chassis[0], vc_position=2)
         Device.objects.filter(pk=devices[1].pk).update(virtual_chassis=virtual_chassis[0], vc_position=2)
         Device.objects.filter(pk=devices[2].pk).update(virtual_chassis=virtual_chassis[0], vc_position=3)
         Device.objects.filter(pk=devices[2].pk).update(virtual_chassis=virtual_chassis[0], vc_position=3)
         Device.objects.filter(pk=devices[4].pk).update(virtual_chassis=virtual_chassis[1], vc_position=2)
         Device.objects.filter(pk=devices[4].pk).update(virtual_chassis=virtual_chassis[1], vc_position=2)
         Device.objects.filter(pk=devices[5].pk).update(virtual_chassis=virtual_chassis[1], vc_position=3)
         Device.objects.filter(pk=devices[5].pk).update(virtual_chassis=virtual_chassis[1], vc_position=3)
+        Device.objects.filter(pk=devices[7].pk).update(virtual_chassis=virtual_chassis[2], vc_position=2)
+        Device.objects.filter(pk=devices[8].pk).update(virtual_chassis=virtual_chassis[2], vc_position=3)
 
 
         cls.update_data = {
         cls.update_data = {
             'master': devices[1].pk,
             'master': devices[1].pk,
@@ -1857,17 +1883,17 @@ class VirtualChassisTest(APIViewTestCases.APIViewTestCase):
 
 
         cls.create_data = [
         cls.create_data = [
             {
             {
-                'master': devices[6].pk,
-                'domain': 'domain-3',
-            },
-            {
-                'master': devices[7].pk,
+                'master': devices[9].pk,
                 'domain': 'domain-4',
                 'domain': 'domain-4',
             },
             },
             {
             {
-                'master': devices[8].pk,
+                'master': devices[10].pk,
                 'domain': 'domain-5',
                 'domain': 'domain-5',
             },
             },
+            {
+                'master': devices[11].pk,
+                'domain': 'domain-6',
+            },
         ]
         ]
 
 
 
 

+ 4 - 0
netbox/extras/models/models.py

@@ -232,6 +232,8 @@ class Graph(models.Model):
         verbose_name='Link URL'
         verbose_name='Link URL'
     )
     )
 
 
+    objects = RestrictedQuerySet.as_manager()
+
     class Meta:
     class Meta:
         ordering = ('type', 'weight', 'name', 'pk')  # (type, weight, name) may be non-unique
         ordering = ('type', 'weight', 'name', 'pk')  # (type, weight, name) may be non-unique
 
 
@@ -299,6 +301,8 @@ class ExportTemplate(models.Model):
         help_text='Extension to append to the rendered filename'
         help_text='Extension to append to the rendered filename'
     )
     )
 
 
+    objects = RestrictedQuerySet.as_manager()
+
     class Meta:
     class Meta:
         ordering = ['content_type', 'name']
         ordering = ['content_type', 'name']
         unique_together = [
         unique_together = [

+ 6 - 0
netbox/extras/tests/test_api.py

@@ -295,6 +295,7 @@ class CreatedUpdatedFilterTest(APITestCase):
         )
         )
 
 
     def test_get_rack_created(self):
     def test_get_rack_created(self):
+        self.add_permissions('dcim.view_rack')
         url = reverse('dcim-api:rack-list')
         url = reverse('dcim-api:rack-list')
         response = self.client.get('{}?created=2001-02-03'.format(url), **self.header)
         response = self.client.get('{}?created=2001-02-03'.format(url), **self.header)
 
 
@@ -302,6 +303,7 @@ class CreatedUpdatedFilterTest(APITestCase):
         self.assertEqual(response.data['results'][0]['id'], self.rack2.pk)
         self.assertEqual(response.data['results'][0]['id'], self.rack2.pk)
 
 
     def test_get_rack_created_gte(self):
     def test_get_rack_created_gte(self):
+        self.add_permissions('dcim.view_rack')
         url = reverse('dcim-api:rack-list')
         url = reverse('dcim-api:rack-list')
         response = self.client.get('{}?created__gte=2001-02-04'.format(url), **self.header)
         response = self.client.get('{}?created__gte=2001-02-04'.format(url), **self.header)
 
 
@@ -309,6 +311,7 @@ class CreatedUpdatedFilterTest(APITestCase):
         self.assertEqual(response.data['results'][0]['id'], self.rack1.pk)
         self.assertEqual(response.data['results'][0]['id'], self.rack1.pk)
 
 
     def test_get_rack_created_lte(self):
     def test_get_rack_created_lte(self):
+        self.add_permissions('dcim.view_rack')
         url = reverse('dcim-api:rack-list')
         url = reverse('dcim-api:rack-list')
         response = self.client.get('{}?created__lte=2001-02-04'.format(url), **self.header)
         response = self.client.get('{}?created__lte=2001-02-04'.format(url), **self.header)
 
 
@@ -316,6 +319,7 @@ class CreatedUpdatedFilterTest(APITestCase):
         self.assertEqual(response.data['results'][0]['id'], self.rack2.pk)
         self.assertEqual(response.data['results'][0]['id'], self.rack2.pk)
 
 
     def test_get_rack_last_updated(self):
     def test_get_rack_last_updated(self):
+        self.add_permissions('dcim.view_rack')
         url = reverse('dcim-api:rack-list')
         url = reverse('dcim-api:rack-list')
         response = self.client.get('{}?last_updated=2001-02-03%2001:02:03.000004'.format(url), **self.header)
         response = self.client.get('{}?last_updated=2001-02-03%2001:02:03.000004'.format(url), **self.header)
 
 
@@ -323,6 +327,7 @@ class CreatedUpdatedFilterTest(APITestCase):
         self.assertEqual(response.data['results'][0]['id'], self.rack2.pk)
         self.assertEqual(response.data['results'][0]['id'], self.rack2.pk)
 
 
     def test_get_rack_last_updated_gte(self):
     def test_get_rack_last_updated_gte(self):
+        self.add_permissions('dcim.view_rack')
         url = reverse('dcim-api:rack-list')
         url = reverse('dcim-api:rack-list')
         response = self.client.get('{}?last_updated__gte=2001-02-04%2001:02:03.000004'.format(url), **self.header)
         response = self.client.get('{}?last_updated__gte=2001-02-04%2001:02:03.000004'.format(url), **self.header)
 
 
@@ -330,6 +335,7 @@ class CreatedUpdatedFilterTest(APITestCase):
         self.assertEqual(response.data['results'][0]['id'], self.rack1.pk)
         self.assertEqual(response.data['results'][0]['id'], self.rack1.pk)
 
 
     def test_get_rack_last_updated_lte(self):
     def test_get_rack_last_updated_lte(self):
+        self.add_permissions('dcim.view_rack')
         url = reverse('dcim-api:rack-list')
         url = reverse('dcim-api:rack-list')
         response = self.client.get('{}?last_updated__lte=2001-02-04%2001:02:03.000004'.format(url), **self.header)
         response = self.client.get('{}?last_updated__lte=2001-02-04%2001:02:03.000004'.format(url), **self.header)
 
 

+ 6 - 11
netbox/extras/tests/test_changelog.py

@@ -4,7 +4,6 @@ from rest_framework import status
 
 
 from dcim.models import Site
 from dcim.models import Site
 from extras.choices import *
 from extras.choices import *
-from extras.constants import *
 from extras.models import CustomField, CustomFieldValue, ObjectChange
 from extras.models import CustomField, CustomFieldValue, ObjectChange
 from utilities.testing import APITestCase
 from utilities.testing import APITestCase
 
 
@@ -26,7 +25,6 @@ class ChangeLogTest(APITestCase):
         cf.obj_type.set([ct])
         cf.obj_type.set([ct])
 
 
     def test_create_object(self):
     def test_create_object(self):
-
         data = {
         data = {
             'name': 'Test Site 1',
             'name': 'Test Site 1',
             'slug': 'test-site-1',
             'slug': 'test-site-1',
@@ -37,10 +35,10 @@ class ChangeLogTest(APITestCase):
                 'bar', 'foo'
                 'bar', 'foo'
             ],
             ],
         }
         }
-
         self.assertEqual(ObjectChange.objects.count(), 0)
         self.assertEqual(ObjectChange.objects.count(), 0)
-
         url = reverse('dcim-api:site-list')
         url = reverse('dcim-api:site-list')
+        self.add_permissions('dcim.add_site')
+
         response = self.client.post(url, data, format='json', **self.header)
         response = self.client.post(url, data, format='json', **self.header)
         self.assertHttpStatus(response, status.HTTP_201_CREATED)
         self.assertHttpStatus(response, status.HTTP_201_CREATED)
 
 
@@ -55,7 +53,6 @@ class ChangeLogTest(APITestCase):
         self.assertListEqual(sorted(oc.object_data['tags']), data['tags'])
         self.assertListEqual(sorted(oc.object_data['tags']), data['tags'])
 
 
     def test_update_object(self):
     def test_update_object(self):
-
         site = Site(name='Test Site 1', slug='test-site-1')
         site = Site(name='Test Site 1', slug='test-site-1')
         site.save()
         site.save()
 
 
@@ -69,10 +66,10 @@ class ChangeLogTest(APITestCase):
                 'abc', 'xyz'
                 'abc', 'xyz'
             ],
             ],
         }
         }
-
         self.assertEqual(ObjectChange.objects.count(), 0)
         self.assertEqual(ObjectChange.objects.count(), 0)
-
+        self.add_permissions('dcim.change_site')
         url = reverse('dcim-api:site-detail', kwargs={'pk': site.pk})
         url = reverse('dcim-api:site-detail', kwargs={'pk': site.pk})
+
         response = self.client.put(url, data, format='json', **self.header)
         response = self.client.put(url, data, format='json', **self.header)
         self.assertHttpStatus(response, status.HTTP_200_OK)
         self.assertHttpStatus(response, status.HTTP_200_OK)
 
 
@@ -87,7 +84,6 @@ class ChangeLogTest(APITestCase):
         self.assertListEqual(sorted(oc.object_data['tags']), data['tags'])
         self.assertListEqual(sorted(oc.object_data['tags']), data['tags'])
 
 
     def test_delete_object(self):
     def test_delete_object(self):
-
         site = Site(
         site = Site(
             name='Test Site 1',
             name='Test Site 1',
             slug='test-site-1'
             slug='test-site-1'
@@ -99,12 +95,11 @@ class ChangeLogTest(APITestCase):
             obj=site,
             obj=site,
             value='ABC'
             value='ABC'
         )
         )
-
         self.assertEqual(ObjectChange.objects.count(), 0)
         self.assertEqual(ObjectChange.objects.count(), 0)
-
+        self.add_permissions('dcim.delete_site')
         url = reverse('dcim-api:site-detail', kwargs={'pk': site.pk})
         url = reverse('dcim-api:site-detail', kwargs={'pk': site.pk})
-        response = self.client.delete(url, **self.header)
 
 
+        response = self.client.delete(url, **self.header)
         self.assertHttpStatus(response, status.HTTP_204_NO_CONTENT)
         self.assertHttpStatus(response, status.HTTP_204_NO_CONTENT)
         self.assertEqual(Site.objects.count(), 0)
         self.assertEqual(Site.objects.count(), 0)
 
 

+ 14 - 8
netbox/extras/tests/test_customfields.py

@@ -182,8 +182,9 @@ class CustomFieldAPITest(APITestCase):
         Validate that custom fields are present on an object even if it has no values defined.
         Validate that custom fields are present on an object even if it has no values defined.
         """
         """
         url = reverse('dcim-api:site-detail', kwargs={'pk': self.sites[0].pk})
         url = reverse('dcim-api:site-detail', kwargs={'pk': self.sites[0].pk})
-        response = self.client.get(url, **self.header)
+        self.add_permissions('dcim.view_site')
 
 
+        response = self.client.get(url, **self.header)
         self.assertEqual(response.data['name'], self.sites[0].name)
         self.assertEqual(response.data['name'], self.sites[0].name)
         self.assertEqual(response.data['custom_fields'], {
         self.assertEqual(response.data['custom_fields'], {
             'text_field': None,
             'text_field': None,
@@ -201,10 +202,10 @@ class CustomFieldAPITest(APITestCase):
         site2_cfvs = {
         site2_cfvs = {
             cfv.field.name: cfv.value for cfv in self.sites[1].custom_field_values.all()
             cfv.field.name: cfv.value for cfv in self.sites[1].custom_field_values.all()
         }
         }
-
         url = reverse('dcim-api:site-detail', kwargs={'pk': self.sites[1].pk})
         url = reverse('dcim-api:site-detail', kwargs={'pk': self.sites[1].pk})
-        response = self.client.get(url, **self.header)
+        self.add_permissions('dcim.view_site')
 
 
+        response = self.client.get(url, **self.header)
         self.assertEqual(response.data['name'], self.sites[1].name)
         self.assertEqual(response.data['name'], self.sites[1].name)
         self.assertEqual(response.data['custom_fields']['text_field'], site2_cfvs['text_field'])
         self.assertEqual(response.data['custom_fields']['text_field'], site2_cfvs['text_field'])
         self.assertEqual(response.data['custom_fields']['number_field'], site2_cfvs['number_field'])
         self.assertEqual(response.data['custom_fields']['number_field'], site2_cfvs['number_field'])
@@ -221,8 +222,9 @@ class CustomFieldAPITest(APITestCase):
             'name': 'Site 3',
             'name': 'Site 3',
             'slug': 'site-3',
             'slug': 'site-3',
         }
         }
-
         url = reverse('dcim-api:site-list')
         url = reverse('dcim-api:site-list')
+        self.add_permissions('dcim.add_site')
+
         response = self.client.post(url, data, format='json', **self.header)
         response = self.client.post(url, data, format='json', **self.header)
         self.assertHttpStatus(response, status.HTTP_201_CREATED)
         self.assertHttpStatus(response, status.HTTP_201_CREATED)
 
 
@@ -263,8 +265,9 @@ class CustomFieldAPITest(APITestCase):
                 'choice_field': self.cf_select_choice2.pk,
                 'choice_field': self.cf_select_choice2.pk,
             },
             },
         }
         }
-
         url = reverse('dcim-api:site-list')
         url = reverse('dcim-api:site-list')
+        self.add_permissions('dcim.add_site')
+
         response = self.client.post(url, data, format='json', **self.header)
         response = self.client.post(url, data, format='json', **self.header)
         self.assertHttpStatus(response, status.HTTP_201_CREATED)
         self.assertHttpStatus(response, status.HTTP_201_CREATED)
 
 
@@ -309,8 +312,9 @@ class CustomFieldAPITest(APITestCase):
                 'slug': 'site-5',
                 'slug': 'site-5',
             },
             },
         )
         )
-
         url = reverse('dcim-api:site-list')
         url = reverse('dcim-api:site-list')
+        self.add_permissions('dcim.add_site')
+
         response = self.client.post(url, data, format='json', **self.header)
         response = self.client.post(url, data, format='json', **self.header)
         self.assertHttpStatus(response, status.HTTP_201_CREATED)
         self.assertHttpStatus(response, status.HTTP_201_CREATED)
         self.assertEqual(len(response.data), len(data))
         self.assertEqual(len(response.data), len(data))
@@ -367,8 +371,9 @@ class CustomFieldAPITest(APITestCase):
                 'custom_fields': custom_field_data,
                 'custom_fields': custom_field_data,
             },
             },
         )
         )
-
         url = reverse('dcim-api:site-list')
         url = reverse('dcim-api:site-list')
+        self.add_permissions('dcim.add_site')
+
         response = self.client.post(url, data, format='json', **self.header)
         response = self.client.post(url, data, format='json', **self.header)
         self.assertHttpStatus(response, status.HTTP_201_CREATED)
         self.assertHttpStatus(response, status.HTTP_201_CREATED)
         self.assertEqual(len(response.data), len(data))
         self.assertEqual(len(response.data), len(data))
@@ -410,8 +415,9 @@ class CustomFieldAPITest(APITestCase):
                 'number_field': 1234,
                 'number_field': 1234,
             },
             },
         }
         }
-
         url = reverse('dcim-api:site-detail', kwargs={'pk': self.sites[1].pk})
         url = reverse('dcim-api:site-detail', kwargs={'pk': self.sites[1].pk})
+        self.add_permissions('dcim.change_site')
+
         response = self.client.patch(url, data, format='json', **self.header)
         response = self.client.patch(url, data, format='json', **self.header)
         self.assertHttpStatus(response, status.HTTP_200_OK)
         self.assertHttpStatus(response, status.HTTP_200_OK)
 
 

+ 4 - 7
netbox/extras/tests/test_tags.py

@@ -15,16 +15,15 @@ class TaggedItemTest(APITestCase):
         super().setUp()
         super().setUp()
 
 
     def test_create_tagged_item(self):
     def test_create_tagged_item(self):
-
         data = {
         data = {
             'name': 'Test Site',
             'name': 'Test Site',
             'slug': 'test-site',
             'slug': 'test-site',
             'tags': ['Foo', 'Bar', 'Baz']
             'tags': ['Foo', 'Bar', 'Baz']
         }
         }
-
         url = reverse('dcim-api:site-list')
         url = reverse('dcim-api:site-list')
-        response = self.client.post(url, data, format='json', **self.header)
+        self.add_permissions('dcim.add_site')
 
 
+        response = self.client.post(url, data, format='json', **self.header)
         self.assertHttpStatus(response, status.HTTP_201_CREATED)
         self.assertHttpStatus(response, status.HTTP_201_CREATED)
         self.assertEqual(sorted(response.data['tags']), sorted(data['tags']))
         self.assertEqual(sorted(response.data['tags']), sorted(data['tags']))
         site = Site.objects.get(pk=response.data['id'])
         site = Site.objects.get(pk=response.data['id'])
@@ -32,20 +31,18 @@ class TaggedItemTest(APITestCase):
         self.assertEqual(sorted(tags), sorted(data['tags']))
         self.assertEqual(sorted(tags), sorted(data['tags']))
 
 
     def test_update_tagged_item(self):
     def test_update_tagged_item(self):
-
         site = Site.objects.create(
         site = Site.objects.create(
             name='Test Site',
             name='Test Site',
             slug='test-site'
             slug='test-site'
         )
         )
         site.tags.add('Foo', 'Bar', 'Baz')
         site.tags.add('Foo', 'Bar', 'Baz')
-
         data = {
         data = {
             'tags': ['Foo', 'Bar', 'New Tag']
             'tags': ['Foo', 'Bar', 'New Tag']
         }
         }
-
+        self.add_permissions('dcim.change_site')
         url = reverse('dcim-api:site-detail', kwargs={'pk': site.pk})
         url = reverse('dcim-api:site-detail', kwargs={'pk': site.pk})
-        response = self.client.patch(url, data, format='json', **self.header)
 
 
+        response = self.client.patch(url, data, format='json', **self.header)
         self.assertHttpStatus(response, status.HTTP_200_OK)
         self.assertHttpStatus(response, status.HTTP_200_OK)
         self.assertEqual(sorted(response.data['tags']), sorted(data['tags']))
         self.assertEqual(sorted(response.data['tags']), sorted(data['tags']))
         site = Site.objects.get(pk=response.data['id'])
         site = Site.objects.get(pk=response.data['id'])

+ 5 - 7
netbox/extras/tests/test_webhooks.py

@@ -42,13 +42,13 @@ class WebhookTest(APITestCase):
             webhook.obj_type.set([site_ct])
             webhook.obj_type.set([site_ct])
 
 
     def test_enqueue_webhook_create(self):
     def test_enqueue_webhook_create(self):
-
         # Create an object via the REST API
         # Create an object via the REST API
         data = {
         data = {
             'name': 'Test Site',
             'name': 'Test Site',
             'slug': 'test-site',
             'slug': 'test-site',
         }
         }
         url = reverse('dcim-api:site-list')
         url = reverse('dcim-api:site-list')
+        self.add_permissions('dcim.add_site')
         response = self.client.post(url, data, format='json', **self.header)
         response = self.client.post(url, data, format='json', **self.header)
         self.assertHttpStatus(response, status.HTTP_201_CREATED)
         self.assertHttpStatus(response, status.HTTP_201_CREATED)
         self.assertEqual(Site.objects.count(), 1)
         self.assertEqual(Site.objects.count(), 1)
@@ -62,14 +62,13 @@ class WebhookTest(APITestCase):
         self.assertEqual(job.args[3], ObjectChangeActionChoices.ACTION_CREATE)
         self.assertEqual(job.args[3], ObjectChangeActionChoices.ACTION_CREATE)
 
 
     def test_enqueue_webhook_update(self):
     def test_enqueue_webhook_update(self):
-
-        site = Site.objects.create(name='Site 1', slug='site-1')
-
         # Update an object via the REST API
         # Update an object via the REST API
+        site = Site.objects.create(name='Site 1', slug='site-1')
         data = {
         data = {
             'comments': 'Updated the site',
             'comments': 'Updated the site',
         }
         }
         url = reverse('dcim-api:site-detail', kwargs={'pk': site.pk})
         url = reverse('dcim-api:site-detail', kwargs={'pk': site.pk})
+        self.add_permissions('dcim.change_site')
         response = self.client.patch(url, data, format='json', **self.header)
         response = self.client.patch(url, data, format='json', **self.header)
         self.assertHttpStatus(response, status.HTTP_200_OK)
         self.assertHttpStatus(response, status.HTTP_200_OK)
 
 
@@ -82,11 +81,10 @@ class WebhookTest(APITestCase):
         self.assertEqual(job.args[3], ObjectChangeActionChoices.ACTION_UPDATE)
         self.assertEqual(job.args[3], ObjectChangeActionChoices.ACTION_UPDATE)
 
 
     def test_enqueue_webhook_delete(self):
     def test_enqueue_webhook_delete(self):
-
-        site = Site.objects.create(name='Site 1', slug='site-1')
-
         # Delete an object via the REST API
         # Delete an object via the REST API
+        site = Site.objects.create(name='Site 1', slug='site-1')
         url = reverse('dcim-api:site-detail', kwargs={'pk': site.pk})
         url = reverse('dcim-api:site-detail', kwargs={'pk': site.pk})
+        self.add_permissions('dcim.delete_site')
         response = self.client.delete(url, **self.header)
         response = self.client.delete(url, **self.header)
         self.assertHttpStatus(response, status.HTTP_204_NO_CONTENT)
         self.assertHttpStatus(response, status.HTTP_204_NO_CONTENT)
 
 

+ 11 - 0
netbox/ipam/tests/test_api.py

@@ -176,6 +176,7 @@ class PrefixTest(APIViewTestCases.APIViewTestCase):
         Prefix.objects.create(prefix=IPNetwork('192.0.2.64/26'))
         Prefix.objects.create(prefix=IPNetwork('192.0.2.64/26'))
         Prefix.objects.create(prefix=IPNetwork('192.0.2.192/27'))
         Prefix.objects.create(prefix=IPNetwork('192.0.2.192/27'))
         url = reverse('ipam-api:prefix-available-prefixes', kwargs={'pk': prefix.pk})
         url = reverse('ipam-api:prefix-available-prefixes', kwargs={'pk': prefix.pk})
+        self.add_permissions('ipam.view_prefix')
 
 
         # Retrieve all available IPs
         # Retrieve all available IPs
         response = self.client.get(url, **self.header)
         response = self.client.get(url, **self.header)
@@ -190,6 +191,7 @@ class PrefixTest(APIViewTestCases.APIViewTestCase):
         vrf = VRF.objects.create(name='Test VRF 1', rd='1234')
         vrf = VRF.objects.create(name='Test VRF 1', rd='1234')
         prefix = Prefix.objects.create(prefix=IPNetwork('192.0.2.0/28'), vrf=vrf, is_pool=True)
         prefix = Prefix.objects.create(prefix=IPNetwork('192.0.2.0/28'), vrf=vrf, is_pool=True)
         url = reverse('ipam-api:prefix-available-prefixes', kwargs={'pk': prefix.pk})
         url = reverse('ipam-api:prefix-available-prefixes', kwargs={'pk': prefix.pk})
+        self.add_permissions('ipam.add_prefix')
 
 
         # Create four available prefixes with individual requests
         # Create four available prefixes with individual requests
         prefixes_to_be_created = [
         prefixes_to_be_created = [
@@ -225,6 +227,7 @@ class PrefixTest(APIViewTestCases.APIViewTestCase):
         """
         """
         prefix = Prefix.objects.create(prefix=IPNetwork('192.0.2.0/28'), is_pool=True)
         prefix = Prefix.objects.create(prefix=IPNetwork('192.0.2.0/28'), is_pool=True)
         url = reverse('ipam-api:prefix-available-prefixes', kwargs={'pk': prefix.pk})
         url = reverse('ipam-api:prefix-available-prefixes', kwargs={'pk': prefix.pk})
+        self.add_permissions('ipam.view_prefix', 'ipam.add_prefix')
 
 
         # Try to create five /30s (only four are available)
         # Try to create five /30s (only four are available)
         data = [
         data = [
@@ -240,6 +243,7 @@ class PrefixTest(APIViewTestCases.APIViewTestCase):
 
 
         # Verify that no prefixes were created (the entire /28 is still available)
         # Verify that no prefixes were created (the entire /28 is still available)
         response = self.client.get(url, **self.header)
         response = self.client.get(url, **self.header)
+        self.assertHttpStatus(response, status.HTTP_200_OK)
         self.assertEqual(response.data[0]['prefix'], '192.0.2.0/28')
         self.assertEqual(response.data[0]['prefix'], '192.0.2.0/28')
 
 
         # Create four /30s in a single request
         # Create four /30s in a single request
@@ -253,6 +257,7 @@ class PrefixTest(APIViewTestCases.APIViewTestCase):
         """
         """
         prefix = Prefix.objects.create(prefix=IPNetwork('192.0.2.0/29'), is_pool=True)
         prefix = Prefix.objects.create(prefix=IPNetwork('192.0.2.0/29'), is_pool=True)
         url = reverse('ipam-api:prefix-available-ips', kwargs={'pk': prefix.pk})
         url = reverse('ipam-api:prefix-available-ips', kwargs={'pk': prefix.pk})
+        self.add_permissions('ipam.view_prefix', 'ipam.view_ipaddress')
 
 
         # Retrieve all available IPs
         # Retrieve all available IPs
         response = self.client.get(url, **self.header)
         response = self.client.get(url, **self.header)
@@ -271,6 +276,8 @@ class PrefixTest(APIViewTestCases.APIViewTestCase):
         vrf = VRF.objects.create(name='Test VRF 1', rd='1234')
         vrf = VRF.objects.create(name='Test VRF 1', rd='1234')
         prefix = Prefix.objects.create(prefix=IPNetwork('192.0.2.0/30'), vrf=vrf, is_pool=True)
         prefix = Prefix.objects.create(prefix=IPNetwork('192.0.2.0/30'), vrf=vrf, is_pool=True)
         url = reverse('ipam-api:prefix-available-ips', kwargs={'pk': prefix.pk})
         url = reverse('ipam-api:prefix-available-ips', kwargs={'pk': prefix.pk})
+        # TODO: ipam.add_prefix should not be required
+        self.add_permissions('ipam.add_prefix', 'ipam.add_ipaddress')
 
 
         # Create all four available IPs with individual requests
         # Create all four available IPs with individual requests
         for i in range(1, 5):
         for i in range(1, 5):
@@ -293,6 +300,8 @@ class PrefixTest(APIViewTestCases.APIViewTestCase):
         """
         """
         prefix = Prefix.objects.create(prefix=IPNetwork('192.0.2.0/29'), is_pool=True)
         prefix = Prefix.objects.create(prefix=IPNetwork('192.0.2.0/29'), is_pool=True)
         url = reverse('ipam-api:prefix-available-ips', kwargs={'pk': prefix.pk})
         url = reverse('ipam-api:prefix-available-ips', kwargs={'pk': prefix.pk})
+        # TODO: ipam.add_prefix, ipam.view_prefix should not be required
+        self.add_permissions('ipam.add_prefix', 'ipam.view_prefix', 'ipam.view_ipaddress', 'ipam.add_ipaddress')
 
 
         # Try to create nine IPs (only eight are available)
         # Try to create nine IPs (only eight are available)
         data = [{'description': 'Test IP {}'.format(i)} for i in range(1, 10)]  # 9 IPs
         data = [{'description': 'Test IP {}'.format(i)} for i in range(1, 10)]  # 9 IPs
@@ -302,6 +311,7 @@ class PrefixTest(APIViewTestCases.APIViewTestCase):
 
 
         # Verify that no IPs were created (eight are still available)
         # Verify that no IPs were created (eight are still available)
         response = self.client.get(url, **self.header)
         response = self.client.get(url, **self.header)
+        self.assertHttpStatus(response, status.HTTP_200_OK)
         self.assertEqual(len(response.data), 8)
         self.assertEqual(len(response.data), 8)
 
 
         # Create all eight available IPs in a single request
         # Create all eight available IPs in a single request
@@ -411,6 +421,7 @@ class VLANTest(APIViewTestCases.APIViewTestCase):
         vlan = VLAN.objects.first()
         vlan = VLAN.objects.first()
         Prefix.objects.create(prefix=IPNetwork('192.0.2.0/24'), vlan=vlan)
         Prefix.objects.create(prefix=IPNetwork('192.0.2.0/24'), vlan=vlan)
 
 
+        self.add_permissions('ipam.delete_vlan')
         url = reverse('ipam-api:vlan-detail', kwargs={'pk': vlan.pk})
         url = reverse('ipam-api:vlan-detail', kwargs={'pk': vlan.pk})
         with disable_warnings('django.request'):
         with disable_warnings('django.request'):
             response = self.client.delete(url, **self.header)
             response = self.client.delete(url, **self.header)

+ 1 - 1
netbox/netbox/tests/test_authentication.py

@@ -11,7 +11,7 @@ from dcim.models import Site
 from ipam.choices import PrefixStatusChoices
 from ipam.choices import PrefixStatusChoices
 from ipam.models import Prefix
 from ipam.models import Prefix
 from users.models import ObjectPermission, Token
 from users.models import ObjectPermission, Token
-from utilities.testing.testcases import TestCase
+from utilities.testing import TestCase
 
 
 
 
 class ExternalAuthenticationTestCase(TestCase):
 class ExternalAuthenticationTestCase(TestCase):

+ 0 - 1
netbox/secrets/api/views.py

@@ -29,7 +29,6 @@ class SecretRoleViewSet(ModelViewSet):
         secret_count=Count('secrets')
         secret_count=Count('secrets')
     )
     )
     serializer_class = serializers.SecretRoleSerializer
     serializer_class = serializers.SecretRoleSerializer
-    permission_classes = [IsAuthenticated]
     filterset_class = filters.SecretRoleFilterSet
     filterset_class = filters.SecretRoleFilterSet
 
 
 
 

+ 26 - 6
netbox/users/api/nested_serializers.py

@@ -1,16 +1,17 @@
 from django.contrib.auth.models import Group, User
 from django.contrib.auth.models import Group, User
+from django.contrib.contenttypes.models import ContentType
+from rest_framework import serializers
 
 
-from utilities.api import WritableNestedSerializer
+from users.models import ObjectPermission
+from utilities.api import ContentTypeField, WritableNestedSerializer
 
 
-_all_ = [
+__all__ = [
+    'NestedGroupSerializer',
+    'NestedObjectPermissionSerializer',
     'NestedUserSerializer',
     'NestedUserSerializer',
 ]
 ]
 
 
 
 
-#
-# Groups and users
-#
-
 class NestedGroupSerializer(WritableNestedSerializer):
 class NestedGroupSerializer(WritableNestedSerializer):
 
 
     class Meta:
     class Meta:
@@ -23,3 +24,22 @@ class NestedUserSerializer(WritableNestedSerializer):
     class Meta:
     class Meta:
         model = User
         model = User
         fields = ['id', 'username']
         fields = ['id', 'username']
+
+
+class NestedObjectPermissionSerializer(WritableNestedSerializer):
+    object_types = ContentTypeField(
+        queryset=ContentType.objects.all(),
+        many=True
+    )
+    groups = serializers.SerializerMethodField(read_only=True)
+    users = serializers.SerializerMethodField(read_only=True)
+
+    class Meta:
+        model = ObjectPermission
+        fields = ['id', 'object_types', 'groups', 'users', 'actions']
+
+    def get_groups(self, obj):
+        return [g.name for g in obj.groups.all()]
+
+    def get_users(self, obj):
+        return [u.username for u in obj.users.all()]

+ 1 - 0
netbox/users/api/serializers.py

@@ -1,3 +1,4 @@
+from django.contrib.auth.models import Group, User
 from django.contrib.contenttypes.models import ContentType
 from django.contrib.contenttypes.models import ContentType
 
 
 from users.models import ObjectPermission
 from users.models import ObjectPermission

+ 3 - 0
netbox/users/models.py

@@ -10,6 +10,7 @@ from django.db.models.signals import post_save
 from django.dispatch import receiver
 from django.dispatch import receiver
 from django.utils import timezone
 from django.utils import timezone
 
 
+from utilities.querysets import RestrictedQuerySet
 from utilities.utils import flatten_dict
 from utilities.utils import flatten_dict
 
 
 
 
@@ -262,6 +263,8 @@ class ObjectPermission(models.Model):
         help_text="Queryset filter matching the applicable objects of the selected type(s)"
         help_text="Queryset filter matching the applicable objects of the selected type(s)"
     )
     )
 
 
+    objects = RestrictedQuerySet.as_manager()
+
     class Meta:
     class Meta:
         verbose_name = "Permission"
         verbose_name = "Permission"
 
 

+ 5 - 75
netbox/users/tests/test_api.py

@@ -1,10 +1,9 @@
 from django.contrib.auth.models import Group, User
 from django.contrib.auth.models import Group, User
 from django.contrib.contenttypes.models import ContentType
 from django.contrib.contenttypes.models import ContentType
 from django.urls import reverse
 from django.urls import reverse
-from rest_framework import status
 
 
 from users.models import ObjectPermission
 from users.models import ObjectPermission
-from utilities.testing import APITestCase
+from utilities.testing import APIViewTestCases, APITestCase
 
 
 
 
 class AppTest(APITestCase):
 class AppTest(APITestCase):
@@ -17,7 +16,9 @@ class AppTest(APITestCase):
         self.assertEqual(response.status_code, 200)
         self.assertEqual(response.status_code, 200)
 
 
 
 
-class ObjectPermissionTest(APITestCase):
+class ObjectPermissionTest(APIViewTestCases.APIViewTestCase):
+    model = ObjectPermission
+    brief_fields = ['actions', 'groups', 'id', 'object_types', 'users']
 
 
     @classmethod
     @classmethod
     def setUpTestData(cls):
     def setUpTestData(cls):
@@ -48,43 +49,7 @@ class ObjectPermissionTest(APITestCase):
             objectpermission.groups.add(groups[i])
             objectpermission.groups.add(groups[i])
             objectpermission.users.add(users[i])
             objectpermission.users.add(users[i])
 
 
-    def test_get_objectpermission(self):
-        objectpermission = ObjectPermission.objects.first()
-        url = reverse('users-api:objectpermission-detail', kwargs={'pk': objectpermission.pk})
-        response = self.client.get(url, **self.header)
-
-        self.assertEqual(response.data['id'], objectpermission.pk)
-
-    def test_list_objectpermissions(self):
-        url = reverse('users-api:objectpermission-list')
-        response = self.client.get(url, **self.header)
-
-        self.assertEqual(response.data['count'], ObjectPermission.objects.count())
-
-    def test_create_objectpermission(self):
-        data = {
-            'object_types': ['dcim.site'],
-            'groups': [Group.objects.first().pk],
-            'users': [User.objects.first().pk],
-            'actions': ['view', 'add', 'change', 'delete'],
-            'constraints': {'name': 'TEST4'},
-        }
-
-        url = reverse('users-api:objectpermission-list')
-        response = self.client.post(url, data, format='json', **self.header)
-
-        self.assertHttpStatus(response, status.HTTP_201_CREATED)
-        self.assertEqual(ObjectPermission.objects.count(), 4)
-        objectpermission = ObjectPermission.objects.get(pk=response.data['id'])
-        self.assertEqual(objectpermission.groups.first().pk, data['groups'][0])
-        self.assertEqual(objectpermission.users.first().pk, data['users'][0])
-        self.assertEqual(objectpermission.actions, data['actions'])
-        self.assertEqual(objectpermission.constraints, data['constraints'])
-
-    def test_create_objectpermission_bulk(self):
-        groups = Group.objects.all()[:3]
-        users = User.objects.all()[:3]
-        data = [
+        cls.create_data = [
             {
             {
                 'object_types': ['dcim.site'],
                 'object_types': ['dcim.site'],
                 'groups': [groups[0].pk],
                 'groups': [groups[0].pk],
@@ -107,38 +72,3 @@ class ObjectPermissionTest(APITestCase):
                 'constraints': {'name': 'TEST6'},
                 'constraints': {'name': 'TEST6'},
             },
             },
         ]
         ]
-
-        url = reverse('users-api:objectpermission-list')
-        response = self.client.post(url, data, format='json', **self.header)
-
-        self.assertHttpStatus(response, status.HTTP_201_CREATED)
-        self.assertEqual(ObjectPermission.objects.count(), 6)
-
-    def test_update_objectpermission(self):
-        objectpermission = ObjectPermission.objects.first()
-        data = {
-            'object_types': ['dcim.site', 'dcim.device'],
-            'groups': [g.pk for g in Group.objects.all()[:2]],
-            'users': [u.pk for u in User.objects.all()[:2]],
-            'actions': ['view'],
-            'constraints': {'name': 'TEST'},
-        }
-
-        url = reverse('users-api:objectpermission-detail', kwargs={'pk': objectpermission.pk})
-        response = self.client.put(url, data, format='json', **self.header)
-
-        self.assertHttpStatus(response, status.HTTP_200_OK)
-        self.assertEqual(ObjectPermission.objects.count(), 3)
-        objectpermission = ObjectPermission.objects.get(pk=response.data['id'])
-        self.assertEqual(objectpermission.groups.first().pk, data['groups'][0])
-        self.assertEqual(objectpermission.users.first().pk, data['users'][0])
-        self.assertEqual(objectpermission.actions, data['actions'])
-        self.assertEqual(objectpermission.constraints, data['constraints'])
-
-    def test_delete_objectpermission(self):
-        objectpermission = ObjectPermission.objects.first()
-        url = reverse('users-api:objectpermission-detail', kwargs={'pk': objectpermission.pk})
-        response = self.client.delete(url, **self.header)
-
-        self.assertHttpStatus(response, status.HTTP_204_NO_CONTENT)
-        self.assertEqual(ObjectPermission.objects.count(), 2)

+ 2 - 1
netbox/utilities/testing/__init__.py

@@ -1,2 +1,3 @@
-from .testcases import *
+from .api import *
 from .utils import *
 from .utils import *
+from .views import *

+ 282 - 0
netbox/utilities/testing/api.py

@@ -0,0 +1,282 @@
+from django.contrib.auth.models import User
+from django.contrib.contenttypes.models import ContentType
+from django.urls import reverse
+from django.test import override_settings
+from rest_framework import status
+from rest_framework.test import APIClient
+
+from users.models import ObjectPermission, Token
+from .utils import disable_warnings
+from .views import TestCase
+
+
+__all__ = (
+    'APITestCase',
+    'APIViewTestCases',
+)
+
+
+#
+# REST API Tests
+#
+
+class APITestCase(TestCase):
+    client_class = APIClient
+    model = None
+
+    def setUp(self):
+        """
+        Create a superuser and token for API calls.
+        """
+        # Create the test user and assign permissions
+        self.user = User.objects.create_user(username='testuser')
+        self.add_permissions(*self.user_permissions)
+        self.token = Token.objects.create(user=self.user)
+        self.header = {'HTTP_AUTHORIZATION': 'Token {}'.format(self.token.key)}
+
+    def _get_detail_url(self, instance):
+        viewname = f'{instance._meta.app_label}-api:{instance._meta.model_name}-detail'
+        return reverse(viewname, kwargs={'pk': instance.pk})
+
+    def _get_list_url(self):
+        viewname = f'{self.model._meta.app_label}-api:{self.model._meta.model_name}-list'
+        return reverse(viewname)
+
+
+class APIViewTestCases:
+
+    class GetObjectViewTestCase(APITestCase):
+
+        @override_settings(EXEMPT_VIEW_PERMISSIONS=['*'])
+        def test_get_object_anonymous(self):
+            """
+            GET a single object as an unauthenticated user.
+            """
+            url = self._get_detail_url(self.model.objects.first())
+            response = self.client.get(url, **self.header)
+            self.assertHttpStatus(response, status.HTTP_200_OK)
+
+        @override_settings(EXEMPT_VIEW_PERMISSIONS=[])
+        def test_get_object_without_permission(self):
+            """
+            GET a single object as an authenticated user without the required permission.
+            """
+            url = self._get_detail_url(self.model.objects.first())
+
+            # Try GET without permission
+            with disable_warnings('django.request'):
+                self.assertHttpStatus(self.client.get(url, **self.header), status.HTTP_403_FORBIDDEN)
+
+        @override_settings(EXEMPT_VIEW_PERMISSIONS=[])
+        def test_get_object(self):
+            """
+            GET a single object as an authenticated user with permission to view the object.
+            """
+            self.assertGreaterEqual(self.model.objects.count(), 2,
+                                    f"Test requires the creation of at least two {self.model} instances")
+            instance1, instance2 = self.model.objects.all()[:2]
+
+            # Add object-level permission
+            obj_perm = ObjectPermission(
+                constraints={'pk': instance1.pk},
+                actions=['view']
+            )
+            obj_perm.save()
+            obj_perm.users.add(self.user)
+            obj_perm.object_types.add(ContentType.objects.get_for_model(self.model))
+
+            # Try GET to permitted object
+            url = self._get_detail_url(instance1)
+            self.assertHttpStatus(self.client.get(url, **self.header), status.HTTP_200_OK)
+
+            # Try GET to non-permitted object
+            url = self._get_detail_url(instance2)
+            self.assertHttpStatus(self.client.get(url, **self.header), status.HTTP_404_NOT_FOUND)
+
+    class ListObjectsViewTestCase(APITestCase):
+        brief_fields = []
+
+        @override_settings(EXEMPT_VIEW_PERMISSIONS=['*'])
+        def test_list_objects_anonymous(self):
+            """
+            GET a list of objects as an unauthenticated user.
+            """
+            url = self._get_list_url()
+            response = self.client.get(url, **self.header)
+
+            self.assertEqual(len(response.data['results']), self.model.objects.count())
+            self.assertHttpStatus(response, status.HTTP_200_OK)
+
+        @override_settings(EXEMPT_VIEW_PERMISSIONS=['*'])
+        def test_list_objects_brief(self):
+            """
+            GET a list of objects using the "brief" parameter as an unauthenticated user.
+            """
+            url = f'{self._get_list_url()}?brief=1'
+            response = self.client.get(url, **self.header)
+
+            self.assertEqual(len(response.data['results']), self.model.objects.count())
+            self.assertEqual(sorted(response.data['results'][0]), self.brief_fields)
+
+        @override_settings(EXEMPT_VIEW_PERMISSIONS=[])
+        def test_list_objects_without_permission(self):
+            """
+            GET a list of objects as an authenticated user without the required permission.
+            """
+            url = self._get_list_url()
+
+            # Try GET without permission
+            with disable_warnings('django.request'):
+                self.assertHttpStatus(self.client.get(url, **self.header), status.HTTP_403_FORBIDDEN)
+
+        @override_settings(EXEMPT_VIEW_PERMISSIONS=[])
+        def test_list_objects(self):
+            """
+            GET a list of objects as an authenticated user with permission to view the objects.
+            """
+            self.assertGreaterEqual(self.model.objects.count(), 3,
+                                    f"Test requires the creation of at least three {self.model} instances")
+            instance1, instance2 = self.model.objects.all()[:2]
+
+            # Add object-level permission
+            obj_perm = ObjectPermission(
+                constraints={'pk__in': [instance1.pk, instance2.pk]},
+                actions=['view']
+            )
+            obj_perm.save()
+            obj_perm.users.add(self.user)
+            obj_perm.object_types.add(ContentType.objects.get_for_model(self.model))
+
+            # Try GET to permitted objects
+            response = self.client.get(self._get_list_url(), **self.header)
+            self.assertHttpStatus(response, status.HTTP_200_OK)
+            self.assertEqual(len(response.data['results']), 2)
+
+    class CreateObjectViewTestCase(APITestCase):
+        create_data = []
+
+        def test_create_object_without_permission(self):
+            """
+            POST a single object without permission.
+            """
+            url = self._get_list_url()
+
+            # Try POST without permission
+            with disable_warnings('django.request'):
+                response = self.client.post(url, self.create_data[0], format='json', **self.header)
+                self.assertHttpStatus(response, status.HTTP_403_FORBIDDEN)
+
+        def test_create_object(self):
+            """
+            POST a single object with permission.
+            """
+            # Add object-level permission
+            obj_perm = ObjectPermission(
+                actions=['add']
+            )
+            obj_perm.save()
+            obj_perm.users.add(self.user)
+            obj_perm.object_types.add(ContentType.objects.get_for_model(self.model))
+
+            initial_count = self.model.objects.count()
+            response = self.client.post(self._get_list_url(), self.create_data[0], format='json', **self.header)
+            self.assertHttpStatus(response, status.HTTP_201_CREATED)
+            self.assertEqual(self.model.objects.count(), initial_count + 1)
+            self.assertInstanceEqual(self.model.objects.get(pk=response.data['id']), self.create_data[0], api=True)
+
+        def test_bulk_create_objects(self):
+            """
+            POST a set of objects in a single request.
+            """
+            # Add object-level permission
+            obj_perm = ObjectPermission(
+                actions=['add']
+            )
+            obj_perm.save()
+            obj_perm.users.add(self.user)
+            obj_perm.object_types.add(ContentType.objects.get_for_model(self.model))
+
+            initial_count = self.model.objects.count()
+            response = self.client.post(self._get_list_url(), self.create_data, format='json', **self.header)
+            self.assertHttpStatus(response, status.HTTP_201_CREATED)
+            self.assertEqual(len(response.data), len(self.create_data))
+            self.assertEqual(self.model.objects.count(), initial_count + len(self.create_data))
+            for i, obj in enumerate(response.data):
+                self.assertInstanceEqual(self.model.objects.get(pk=obj['id']), self.create_data[i], api=True)
+
+    class UpdateObjectViewTestCase(APITestCase):
+        update_data = {}
+
+        def test_update_object_without_permission(self):
+            """
+            PATCH a single object without permission.
+            """
+            url = self._get_detail_url(self.model.objects.first())
+            update_data = self.update_data or getattr(self, 'create_data')[0]
+
+            # Try PATCH without permission
+            with disable_warnings('django.request'):
+                response = self.client.patch(url, update_data, format='json', **self.header)
+                self.assertHttpStatus(response, status.HTTP_403_FORBIDDEN)
+
+        def test_update_object(self):
+            """
+            PATCH a single object identified by its numeric ID.
+            """
+            instance = self.model.objects.first()
+            url = self._get_detail_url(instance)
+            update_data = self.update_data or getattr(self, 'create_data')[0]
+
+            # Add object-level permission
+            obj_perm = ObjectPermission(
+                actions=['change']
+            )
+            obj_perm.save()
+            obj_perm.users.add(self.user)
+            obj_perm.object_types.add(ContentType.objects.get_for_model(self.model))
+
+            response = self.client.patch(url, update_data, format='json', **self.header)
+            self.assertHttpStatus(response, status.HTTP_200_OK)
+            instance.refresh_from_db()
+            self.assertInstanceEqual(instance, self.update_data, api=True)
+
+    class DeleteObjectViewTestCase(APITestCase):
+
+        def test_delete_object_without_permission(self):
+            """
+            DELETE a single object without permission.
+            """
+            url = self._get_detail_url(self.model.objects.first())
+
+            # Try DELETE without permission
+            with disable_warnings('django.request'):
+                response = self.client.delete(url, **self.header)
+                self.assertHttpStatus(response, status.HTTP_403_FORBIDDEN)
+
+        def test_delete_object(self):
+            """
+            DELETE a single object identified by its numeric ID.
+            """
+            instance = self.model.objects.first()
+            url = self._get_detail_url(instance)
+
+            # Add object-level permission
+            obj_perm = ObjectPermission(
+                actions=['delete']
+            )
+            obj_perm.save()
+            obj_perm.users.add(self.user)
+            obj_perm.object_types.add(ContentType.objects.get_for_model(self.model))
+
+            response = self.client.delete(url, **self.header)
+            self.assertHttpStatus(response, status.HTTP_204_NO_CONTENT)
+            self.assertFalse(self.model.objects.filter(pk=instance.pk).exists())
+
+    class APIViewTestCase(
+        GetObjectViewTestCase,
+        ListObjectsViewTestCase,
+        CreateObjectViewTestCase,
+        UpdateObjectViewTestCase,
+        DeleteObjectViewTestCase
+    ):
+        pass

+ 14 - 138
netbox/utilities/testing/testcases.py → netbox/utilities/testing/views.py

@@ -1,18 +1,24 @@
 from django.contrib.auth.models import User
 from django.contrib.auth.models import User
 from django.contrib.contenttypes.models import ContentType
 from django.contrib.contenttypes.models import ContentType
 from django.core.exceptions import ObjectDoesNotExist
 from django.core.exceptions import ObjectDoesNotExist
+from django.db.models import ForeignKey, ManyToManyField
 from django.forms.models import model_to_dict
 from django.forms.models import model_to_dict
 from django.test import Client, TestCase as _TestCase, override_settings
 from django.test import Client, TestCase as _TestCase, override_settings
 from django.urls import reverse, NoReverseMatch
 from django.urls import reverse, NoReverseMatch
 from netaddr import IPNetwork
 from netaddr import IPNetwork
-from rest_framework import status
-from rest_framework.test import APIClient
 
 
-from users.models import ObjectPermission, Token
+from users.models import ObjectPermission
 from utilities.permissions import resolve_permission_ct
 from utilities.permissions import resolve_permission_ct
 from .utils import disable_warnings, post_data
 from .utils import disable_warnings, post_data
 
 
 
 
+__all__ = (
+    'TestCase',
+    'ModelViewTestCase',
+    'ViewTestCases',
+)
+
+
 class TestCase(_TestCase):
 class TestCase(_TestCase):
     user_permissions = ()
     user_permissions = ()
 
 
@@ -78,12 +84,15 @@ class TestCase(_TestCase):
             if api:
             if api:
 
 
                 # Replace ContentType numeric IDs with <app_label>.<model>
                 # Replace ContentType numeric IDs with <app_label>.<model>
-                if type(getattr(instance, key)) is ContentType:
+                field = instance._meta.get_field(key)
+                if type(field) is ForeignKey and field.related_model is ContentType:
                     ct = ContentType.objects.get(pk=value)
                     ct = ContentType.objects.get(pk=value)
                     model_dict[key] = f'{ct.app_label}.{ct.model}'
                     model_dict[key] = f'{ct.app_label}.{ct.model}'
+                elif type(field) is ManyToManyField and field.related_model is ContentType:
+                    model_dict[key] = [f'{ct.app_label}.{ct.model}' for ct in value]
 
 
                 # Convert IPNetwork instances to strings
                 # Convert IPNetwork instances to strings
-                if type(value) is IPNetwork:
+                elif type(value) is IPNetwork:
                     model_dict[key] = str(value)
                     model_dict[key] = str(value)
 
 
         # Omit any dictionary keys which are not instance attributes
         # Omit any dictionary keys which are not instance attributes
@@ -202,13 +211,6 @@ class ViewTestCases:
             # Try GET to non-permitted object
             # Try GET to non-permitted object
             self.assertHttpStatus(self.client.get(instance2.get_absolute_url()), 404)
             self.assertHttpStatus(self.client.get(instance2.get_absolute_url()), 404)
 
 
-        @override_settings(EXEMPT_VIEW_PERMISSIONS=['*'])
-        def test_get_object_anonymous(self):
-            # Make the request as an unauthenticated user
-            self.client.logout()
-            response = self.client.get(self.model.objects.first().get_absolute_url())
-            self.assertHttpStatus(response, 200)
-
     class CreateObjectViewTestCase(ModelViewTestCase):
     class CreateObjectViewTestCase(ModelViewTestCase):
         """
         """
         Create a single new instance.
         Create a single new instance.
@@ -799,129 +801,3 @@ class ViewTestCases:
         TestCase suitable for testing device component models (ConsolePorts, Interfaces, etc.)
         TestCase suitable for testing device component models (ConsolePorts, Interfaces, etc.)
         """
         """
         maxDiff = None
         maxDiff = None
-
-
-#
-# REST API Tests
-#
-
-class APITestCase(TestCase):
-    client_class = APIClient
-    model = None
-
-    def setUp(self):
-        """
-        Create a superuser and token for API calls.
-        """
-        self.user = User.objects.create(username='testuser', is_superuser=True)
-        self.token = Token.objects.create(user=self.user)
-        self.header = {'HTTP_AUTHORIZATION': 'Token {}'.format(self.token.key)}
-
-    def _get_detail_url(self, instance):
-        viewname = f'{instance._meta.app_label}-api:{instance._meta.model_name}-detail'
-        return reverse(viewname, kwargs={'pk': instance.pk})
-
-    def _get_list_url(self):
-        viewname = f'{self.model._meta.app_label}-api:{self.model._meta.model_name}-list'
-        return reverse(viewname)
-
-
-class APIViewTestCases:
-
-    class GetObjectViewTestCase(APITestCase):
-
-        def test_get_object(self):
-            """
-            GET a single object identified by its numeric ID.
-            """
-            instance = self.model.objects.first()
-            url = self._get_detail_url(instance)
-            response = self.client.get(url, **self.header)
-
-            self.assertEqual(response.data['id'], instance.pk)
-
-    class ListObjectsViewTestCase(APITestCase):
-        brief_fields = []
-
-        def test_list_objects(self):
-            """
-            GET a list of objects.
-            """
-            url = self._get_list_url()
-            response = self.client.get(url, **self.header)
-
-            self.assertEqual(len(response.data['results']), self.model.objects.count())
-
-        def test_list_objects_brief(self):
-            """
-            GET a list of objects using the "brief" parameter.
-            """
-            url = f'{self._get_list_url()}?brief=1'
-            response = self.client.get(url, **self.header)
-
-            self.assertEqual(len(response.data['results']), self.model.objects.count())
-            self.assertEqual(sorted(response.data['results'][0]), self.brief_fields)
-
-    class CreateObjectViewTestCase(APITestCase):
-        create_data = []
-
-        def test_create_object(self):
-            """
-            POST a single object.
-            """
-            initial_count = self.model.objects.count()
-            url = self._get_list_url()
-            response = self.client.post(url, self.create_data[0], format='json', **self.header)
-
-            self.assertHttpStatus(response, status.HTTP_201_CREATED)
-            self.assertEqual(self.model.objects.count(), initial_count + 1)
-            self.assertInstanceEqual(self.model.objects.get(pk=response.data['id']), self.create_data[0], api=True)
-
-        def test_bulk_create_object(self):
-            """
-            POST a set of objects in a single request.
-            """
-            initial_count = self.model.objects.count()
-            url = self._get_list_url()
-            response = self.client.post(url, self.create_data, format='json', **self.header)
-
-            self.assertHttpStatus(response, status.HTTP_201_CREATED)
-            self.assertEqual(self.model.objects.count(), initial_count + len(self.create_data))
-
-    class UpdateObjectViewTestCase(APITestCase):
-        update_data = {}
-
-        def test_update_object(self):
-            """
-            PATCH a single object identified by its numeric ID.
-            """
-            instance = self.model.objects.first()
-            url = self._get_detail_url(instance)
-            update_data = self.update_data or getattr(self, 'create_data')[0]
-            response = self.client.patch(url, update_data, format='json', **self.header)
-
-            self.assertHttpStatus(response, status.HTTP_200_OK)
-            instance.refresh_from_db()
-            self.assertInstanceEqual(instance, self.update_data, api=True)
-
-    class DeleteObjectViewTestCase(APITestCase):
-
-        def test_delete_object(self):
-            """
-            DELETE a single object identified by its numeric ID.
-            """
-            instance = self.model.objects.first()
-            url = self._get_detail_url(instance)
-            response = self.client.delete(url, **self.header)
-
-            self.assertHttpStatus(response, status.HTTP_204_NO_CONTENT)
-            self.assertFalse(self.model.objects.filter(pk=instance.pk).exists())
-
-    class APIViewTestCase(
-        GetObjectViewTestCase,
-        ListObjectsViewTestCase,
-        CreateObjectViewTestCase,
-        UpdateObjectViewTestCase,
-        DeleteObjectViewTestCase
-    ):
-        pass

+ 12 - 19
netbox/utilities/tests/test_api.py

@@ -18,7 +18,6 @@ class WritableNestedSerializerTest(APITestCase):
     """
     """
 
 
     def setUp(self):
     def setUp(self):
-
         super().setUp()
         super().setUp()
 
 
         self.region_a = Region.objects.create(name='Region A', slug='region-a')
         self.region_a = Region.objects.create(name='Region A', slug='region-a')
@@ -26,39 +25,36 @@ class WritableNestedSerializerTest(APITestCase):
         self.site2 = Site.objects.create(region=self.region_a, name='Site 2', slug='site-2')
         self.site2 = Site.objects.create(region=self.region_a, name='Site 2', slug='site-2')
 
 
     def test_related_by_pk(self):
     def test_related_by_pk(self):
-
         data = {
         data = {
             'vid': 100,
             'vid': 100,
             'name': 'Test VLAN 100',
             'name': 'Test VLAN 100',
             'site': self.site1.pk,
             'site': self.site1.pk,
         }
         }
-
         url = reverse('ipam-api:vlan-list')
         url = reverse('ipam-api:vlan-list')
-        response = self.client.post(url, data, format='json', **self.header)
+        self.add_permissions('ipam.add_vlan')
 
 
+        response = self.client.post(url, data, format='json', **self.header)
         self.assertHttpStatus(response, status.HTTP_201_CREATED)
         self.assertHttpStatus(response, status.HTTP_201_CREATED)
         self.assertEqual(response.data['site']['id'], self.site1.pk)
         self.assertEqual(response.data['site']['id'], self.site1.pk)
         vlan = VLAN.objects.get(pk=response.data['id'])
         vlan = VLAN.objects.get(pk=response.data['id'])
         self.assertEqual(vlan.site, self.site1)
         self.assertEqual(vlan.site, self.site1)
 
 
     def test_related_by_pk_no_match(self):
     def test_related_by_pk_no_match(self):
-
         data = {
         data = {
             'vid': 100,
             'vid': 100,
             'name': 'Test VLAN 100',
             'name': 'Test VLAN 100',
             'site': 999,
             'site': 999,
         }
         }
-
         url = reverse('ipam-api:vlan-list')
         url = reverse('ipam-api:vlan-list')
+        self.add_permissions('ipam.add_vlan')
+
         with disable_warnings('django.request'):
         with disable_warnings('django.request'):
             response = self.client.post(url, data, format='json', **self.header)
             response = self.client.post(url, data, format='json', **self.header)
-
         self.assertHttpStatus(response, status.HTTP_400_BAD_REQUEST)
         self.assertHttpStatus(response, status.HTTP_400_BAD_REQUEST)
         self.assertEqual(VLAN.objects.count(), 0)
         self.assertEqual(VLAN.objects.count(), 0)
         self.assertTrue(response.data['site'][0].startswith("Related object not found"))
         self.assertTrue(response.data['site'][0].startswith("Related object not found"))
 
 
     def test_related_by_attributes(self):
     def test_related_by_attributes(self):
-
         data = {
         data = {
             'vid': 100,
             'vid': 100,
             'name': 'Test VLAN 100',
             'name': 'Test VLAN 100',
@@ -66,17 +62,16 @@ class WritableNestedSerializerTest(APITestCase):
                 'name': 'Site 1'
                 'name': 'Site 1'
             },
             },
         }
         }
-
         url = reverse('ipam-api:vlan-list')
         url = reverse('ipam-api:vlan-list')
-        response = self.client.post(url, data, format='json', **self.header)
+        self.add_permissions('ipam.add_vlan')
 
 
+        response = self.client.post(url, data, format='json', **self.header)
         self.assertHttpStatus(response, status.HTTP_201_CREATED)
         self.assertHttpStatus(response, status.HTTP_201_CREATED)
         self.assertEqual(response.data['site']['id'], self.site1.pk)
         self.assertEqual(response.data['site']['id'], self.site1.pk)
         vlan = VLAN.objects.get(pk=response.data['id'])
         vlan = VLAN.objects.get(pk=response.data['id'])
         self.assertEqual(vlan.site, self.site1)
         self.assertEqual(vlan.site, self.site1)
 
 
     def test_related_by_attributes_no_match(self):
     def test_related_by_attributes_no_match(self):
-
         data = {
         data = {
             'vid': 100,
             'vid': 100,
             'name': 'Test VLAN 100',
             'name': 'Test VLAN 100',
@@ -84,17 +79,16 @@ class WritableNestedSerializerTest(APITestCase):
                 'name': 'Site X'
                 'name': 'Site X'
             },
             },
         }
         }
-
         url = reverse('ipam-api:vlan-list')
         url = reverse('ipam-api:vlan-list')
+        self.add_permissions('ipam.add_vlan')
+
         with disable_warnings('django.request'):
         with disable_warnings('django.request'):
             response = self.client.post(url, data, format='json', **self.header)
             response = self.client.post(url, data, format='json', **self.header)
-
         self.assertHttpStatus(response, status.HTTP_400_BAD_REQUEST)
         self.assertHttpStatus(response, status.HTTP_400_BAD_REQUEST)
         self.assertEqual(VLAN.objects.count(), 0)
         self.assertEqual(VLAN.objects.count(), 0)
         self.assertTrue(response.data['site'][0].startswith("Related object not found"))
         self.assertTrue(response.data['site'][0].startswith("Related object not found"))
 
 
     def test_related_by_attributes_multiple_matches(self):
     def test_related_by_attributes_multiple_matches(self):
-
         data = {
         data = {
             'vid': 100,
             'vid': 100,
             'name': 'Test VLAN 100',
             'name': 'Test VLAN 100',
@@ -104,27 +98,26 @@ class WritableNestedSerializerTest(APITestCase):
                 },
                 },
             },
             },
         }
         }
-
         url = reverse('ipam-api:vlan-list')
         url = reverse('ipam-api:vlan-list')
+        self.add_permissions('ipam.add_vlan')
+
         with disable_warnings('django.request'):
         with disable_warnings('django.request'):
             response = self.client.post(url, data, format='json', **self.header)
             response = self.client.post(url, data, format='json', **self.header)
-
         self.assertHttpStatus(response, status.HTTP_400_BAD_REQUEST)
         self.assertHttpStatus(response, status.HTTP_400_BAD_REQUEST)
         self.assertEqual(VLAN.objects.count(), 0)
         self.assertEqual(VLAN.objects.count(), 0)
         self.assertTrue(response.data['site'][0].startswith("Multiple objects match"))
         self.assertTrue(response.data['site'][0].startswith("Multiple objects match"))
 
 
     def test_related_by_invalid(self):
     def test_related_by_invalid(self):
-
         data = {
         data = {
             'vid': 100,
             'vid': 100,
             'name': 'Test VLAN 100',
             'name': 'Test VLAN 100',
             'site': 'XXX',
             'site': 'XXX',
         }
         }
-
         url = reverse('ipam-api:vlan-list')
         url = reverse('ipam-api:vlan-list')
+        self.add_permissions('ipam.add_vlan')
+
         with disable_warnings('django.request'):
         with disable_warnings('django.request'):
             response = self.client.post(url, data, format='json', **self.header)
             response = self.client.post(url, data, format='json', **self.header)
-
         self.assertHttpStatus(response, status.HTTP_400_BAD_REQUEST)
         self.assertHttpStatus(response, status.HTTP_400_BAD_REQUEST)
         self.assertEqual(VLAN.objects.count(), 0)
         self.assertEqual(VLAN.objects.count(), 0)
 
 

+ 25 - 28
netbox/virtualization/tests/test_api.py

@@ -164,10 +164,10 @@ class VirtualMachineTest(APIViewTestCases.APIViewTestCase):
         Check that config context data is included by default in the virtual machines list.
         Check that config context data is included by default in the virtual machines list.
         """
         """
         virtualmachine = VirtualMachine.objects.first()
         virtualmachine = VirtualMachine.objects.first()
-        url = reverse('virtualization-api:virtualmachine-list')
-        url = '{}?id={}'.format(url, virtualmachine.pk)
-        response = self.client.get(url, **self.header)
+        url = '{}?id={}'.format(reverse('virtualization-api:virtualmachine-list'), virtualmachine.pk)
+        self.add_permissions('virtualization.view_virtualmachine')
 
 
+        response = self.client.get(url, **self.header)
         self.assertEqual(response.data['results'][0].get('config_context', {}).get('A'), 1)
         self.assertEqual(response.data['results'][0].get('config_context', {}).get('A'), 1)
 
 
     def test_config_context_excluded(self):
     def test_config_context_excluded(self):
@@ -175,8 +175,9 @@ class VirtualMachineTest(APIViewTestCases.APIViewTestCase):
         Check that config context data can be excluded by passing ?exclude=config_context.
         Check that config context data can be excluded by passing ?exclude=config_context.
         """
         """
         url = reverse('virtualization-api:virtualmachine-list') + '?exclude=config_context'
         url = reverse('virtualization-api:virtualmachine-list') + '?exclude=config_context'
-        response = self.client.get(url, **self.header)
+        self.add_permissions('virtualization.view_virtualmachine')
 
 
+        response = self.client.get(url, **self.header)
         self.assertFalse('config_context' in response.data['results'][0])
         self.assertFalse('config_context' in response.data['results'][0])
 
 
     def test_unique_name_per_cluster_constraint(self):
     def test_unique_name_per_cluster_constraint(self):
@@ -188,8 +189,9 @@ class VirtualMachineTest(APIViewTestCases.APIViewTestCase):
             'cluster': Cluster.objects.first().pk,
             'cluster': Cluster.objects.first().pk,
         }
         }
         url = reverse('virtualization-api:virtualmachine-list')
         url = reverse('virtualization-api:virtualmachine-list')
-        response = self.client.post(url, data, format='json', **self.header)
+        self.add_permissions('virtualization.add_virtualmachine')
 
 
+        response = self.client.post(url, data, format='json', **self.header)
         self.assertHttpStatus(response, status.HTTP_400_BAD_REQUEST)
         self.assertHttpStatus(response, status.HTTP_400_BAD_REQUEST)
 
 
 
 
@@ -224,39 +226,38 @@ class InterfaceTest(APITestCase):
         self.vlan3 = VLAN.objects.create(name="Test VLAN 3", vid=3)
         self.vlan3 = VLAN.objects.create(name="Test VLAN 3", vid=3)
 
 
     def test_get_interface(self):
     def test_get_interface(self):
-
         url = reverse('virtualization-api:interface-detail', kwargs={'pk': self.interface1.pk})
         url = reverse('virtualization-api:interface-detail', kwargs={'pk': self.interface1.pk})
-        response = self.client.get(url, **self.header)
+        self.add_permissions('dcim.view_interface')
 
 
+        response = self.client.get(url, **self.header)
         self.assertEqual(response.data['name'], self.interface1.name)
         self.assertEqual(response.data['name'], self.interface1.name)
 
 
     def test_list_interfaces(self):
     def test_list_interfaces(self):
-
         url = reverse('virtualization-api:interface-list')
         url = reverse('virtualization-api:interface-list')
-        response = self.client.get(url, **self.header)
+        self.add_permissions('dcim.view_interface')
 
 
+        response = self.client.get(url, **self.header)
         self.assertEqual(response.data['count'], 3)
         self.assertEqual(response.data['count'], 3)
 
 
     def test_list_interfaces_brief(self):
     def test_list_interfaces_brief(self):
-
         url = reverse('virtualization-api:interface-list')
         url = reverse('virtualization-api:interface-list')
-        response = self.client.get('{}?brief=1'.format(url), **self.header)
+        self.add_permissions('dcim.view_interface')
 
 
+        response = self.client.get('{}?brief=1'.format(url), **self.header)
         self.assertEqual(
         self.assertEqual(
             sorted(response.data['results'][0]),
             sorted(response.data['results'][0]),
             ['id', 'name', 'url', 'virtual_machine']
             ['id', 'name', 'url', 'virtual_machine']
         )
         )
 
 
     def test_create_interface(self):
     def test_create_interface(self):
-
         data = {
         data = {
             'virtual_machine': self.virtualmachine.pk,
             'virtual_machine': self.virtualmachine.pk,
             'name': 'Test Interface 4',
             'name': 'Test Interface 4',
         }
         }
-
         url = reverse('virtualization-api:interface-list')
         url = reverse('virtualization-api:interface-list')
-        response = self.client.post(url, data, format='json', **self.header)
+        self.add_permissions('dcim.add_interface')
 
 
+        response = self.client.post(url, data, format='json', **self.header)
         self.assertHttpStatus(response, status.HTTP_201_CREATED)
         self.assertHttpStatus(response, status.HTTP_201_CREATED)
         self.assertEqual(Interface.objects.count(), 4)
         self.assertEqual(Interface.objects.count(), 4)
         interface4 = Interface.objects.get(pk=response.data['id'])
         interface4 = Interface.objects.get(pk=response.data['id'])
@@ -264,7 +265,6 @@ class InterfaceTest(APITestCase):
         self.assertEqual(interface4.name, data['name'])
         self.assertEqual(interface4.name, data['name'])
 
 
     def test_create_interface_with_802_1q(self):
     def test_create_interface_with_802_1q(self):
-
         data = {
         data = {
             'virtual_machine': self.virtualmachine.pk,
             'virtual_machine': self.virtualmachine.pk,
             'name': 'Test Interface 4',
             'name': 'Test Interface 4',
@@ -272,10 +272,10 @@ class InterfaceTest(APITestCase):
             'untagged_vlan': self.vlan3.id,
             'untagged_vlan': self.vlan3.id,
             'tagged_vlans': [self.vlan1.id, self.vlan2.id],
             'tagged_vlans': [self.vlan1.id, self.vlan2.id],
         }
         }
-
         url = reverse('virtualization-api:interface-list')
         url = reverse('virtualization-api:interface-list')
-        response = self.client.post(url, data, format='json', **self.header)
+        self.add_permissions('dcim.add_interface')
 
 
+        response = self.client.post(url, data, format='json', **self.header)
         self.assertHttpStatus(response, status.HTTP_201_CREATED)
         self.assertHttpStatus(response, status.HTTP_201_CREATED)
         self.assertEqual(Interface.objects.count(), 4)
         self.assertEqual(Interface.objects.count(), 4)
         self.assertEqual(response.data['virtual_machine']['id'], data['virtual_machine'])
         self.assertEqual(response.data['virtual_machine']['id'], data['virtual_machine'])
@@ -284,7 +284,6 @@ class InterfaceTest(APITestCase):
         self.assertEqual([v['id'] for v in response.data['tagged_vlans']], data['tagged_vlans'])
         self.assertEqual([v['id'] for v in response.data['tagged_vlans']], data['tagged_vlans'])
 
 
     def test_create_interface_bulk(self):
     def test_create_interface_bulk(self):
-
         data = [
         data = [
             {
             {
                 'virtual_machine': self.virtualmachine.pk,
                 'virtual_machine': self.virtualmachine.pk,
@@ -299,10 +298,10 @@ class InterfaceTest(APITestCase):
                 'name': 'Test Interface 6',
                 'name': 'Test Interface 6',
             },
             },
         ]
         ]
-
         url = reverse('virtualization-api:interface-list')
         url = reverse('virtualization-api:interface-list')
-        response = self.client.post(url, data, format='json', **self.header)
+        self.add_permissions('dcim.add_interface')
 
 
+        response = self.client.post(url, data, format='json', **self.header)
         self.assertHttpStatus(response, status.HTTP_201_CREATED)
         self.assertHttpStatus(response, status.HTTP_201_CREATED)
         self.assertEqual(Interface.objects.count(), 6)
         self.assertEqual(Interface.objects.count(), 6)
         self.assertEqual(response.data[0]['name'], data[0]['name'])
         self.assertEqual(response.data[0]['name'], data[0]['name'])
@@ -310,7 +309,6 @@ class InterfaceTest(APITestCase):
         self.assertEqual(response.data[2]['name'], data[2]['name'])
         self.assertEqual(response.data[2]['name'], data[2]['name'])
 
 
     def test_create_interface_802_1q_bulk(self):
     def test_create_interface_802_1q_bulk(self):
-
         data = [
         data = [
             {
             {
                 'virtual_machine': self.virtualmachine.pk,
                 'virtual_machine': self.virtualmachine.pk,
@@ -334,10 +332,10 @@ class InterfaceTest(APITestCase):
                 'tagged_vlans': [self.vlan1.id],
                 'tagged_vlans': [self.vlan1.id],
             },
             },
         ]
         ]
-
         url = reverse('virtualization-api:interface-list')
         url = reverse('virtualization-api:interface-list')
-        response = self.client.post(url, data, format='json', **self.header)
+        self.add_permissions('dcim.add_interface')
 
 
+        response = self.client.post(url, data, format='json', **self.header)
         self.assertHttpStatus(response, status.HTTP_201_CREATED)
         self.assertHttpStatus(response, status.HTTP_201_CREATED)
         self.assertEqual(Interface.objects.count(), 6)
         self.assertEqual(Interface.objects.count(), 6)
         for i in range(0, 3):
         for i in range(0, 3):
@@ -346,24 +344,23 @@ class InterfaceTest(APITestCase):
             self.assertEqual(response.data[i]['untagged_vlan']['id'], data[i]['untagged_vlan'])
             self.assertEqual(response.data[i]['untagged_vlan']['id'], data[i]['untagged_vlan'])
 
 
     def test_update_interface(self):
     def test_update_interface(self):
-
         data = {
         data = {
             'virtual_machine': self.virtualmachine.pk,
             'virtual_machine': self.virtualmachine.pk,
             'name': 'Test Interface X',
             'name': 'Test Interface X',
         }
         }
-
         url = reverse('virtualization-api:interface-detail', kwargs={'pk': self.interface1.pk})
         url = reverse('virtualization-api:interface-detail', kwargs={'pk': self.interface1.pk})
-        response = self.client.put(url, data, format='json', **self.header)
+        self.add_permissions('dcim.change_interface')
 
 
+        response = self.client.put(url, data, format='json', **self.header)
         self.assertHttpStatus(response, status.HTTP_200_OK)
         self.assertHttpStatus(response, status.HTTP_200_OK)
         self.assertEqual(Interface.objects.count(), 3)
         self.assertEqual(Interface.objects.count(), 3)
         interface1 = Interface.objects.get(pk=response.data['id'])
         interface1 = Interface.objects.get(pk=response.data['id'])
         self.assertEqual(interface1.name, data['name'])
         self.assertEqual(interface1.name, data['name'])
 
 
     def test_delete_interface(self):
     def test_delete_interface(self):
-
         url = reverse('virtualization-api:interface-detail', kwargs={'pk': self.interface1.pk})
         url = reverse('virtualization-api:interface-detail', kwargs={'pk': self.interface1.pk})
-        response = self.client.delete(url, **self.header)
+        self.add_permissions('dcim.delete_interface')
 
 
+        response = self.client.delete(url, **self.header)
         self.assertHttpStatus(response, status.HTTP_204_NO_CONTENT)
         self.assertHttpStatus(response, status.HTTP_204_NO_CONTENT)
         self.assertEqual(Interface.objects.count(), 2)
         self.assertEqual(Interface.objects.count(), 2)