Parcourir la source

Move TreeNodeMultipleChoiceFilter tests to utilities (follow-up to #3616)

Jeremy Stretch il y a 6 ans
Parent
commit
875e09013c

+ 11 - 57
netbox/dcim/tests/test_api.py

@@ -127,9 +127,6 @@ class SiteTest(APITestCase):
         self.region2 = Region.objects.create(name='Test Region 2', slug='test-region-2')
         self.region2 = Region.objects.create(name='Test Region 2', slug='test-region-2')
         self.site1 = Site.objects.create(region=self.region1, name='Test Site 1', slug='test-site-1')
         self.site1 = Site.objects.create(region=self.region1, name='Test Site 1', slug='test-site-1')
         self.site2 = Site.objects.create(region=self.region1, name='Test Site 2', slug='test-site-2')
         self.site2 = Site.objects.create(region=self.region1, name='Test Site 2', slug='test-site-2')
-        self.site3 = Site.objects.create(region=self.region2, name='Test Site 3', slug='test-site-3')
-        self.site_non_region1 = Site.objects.create(name='Test Site Null Region1', slug='test-site-no-region1')
-        self.site_non_region2 = Site.objects.create(name='Test Site Null Region2', slug='test-site-no-region2')
 
 
     def test_get_site(self):
     def test_get_site(self):
 
 
@@ -164,7 +161,7 @@ class SiteTest(APITestCase):
         url = reverse('dcim-api:site-list')
         url = reverse('dcim-api:site-list')
         response = self.client.get(url, **self.header)
         response = self.client.get(url, **self.header)
 
 
-        self.assertEqual(response.data['count'], 5)
+        self.assertEqual(response.data['count'], 3)
 
 
     def test_list_sites_brief(self):
     def test_list_sites_brief(self):
 
 
@@ -176,20 +173,6 @@ class SiteTest(APITestCase):
             ['id', 'name', 'slug', 'url']
             ['id', 'name', 'slug', 'url']
         )
         )
 
 
-    def test_list_sites_null_region(self):
-
-        url = reverse('dcim-api:site-list')
-        response = self.client.get('{}?region=null'.format(url), **self.header)
-
-        self.assertEqual(response.data['count'], 2)
-
-    def test_list_sites_multiple_regions(self):
-
-        url = reverse('dcim-api:site-list')
-        response = self.client.get('{}?region=null&region=test-region-1'.format(url), **self.header)
-
-        self.assertEqual(response.data['count'], 4)
-
     def test_create_site(self):
     def test_create_site(self):
 
 
         data = {
         data = {
@@ -203,7 +186,7 @@ class SiteTest(APITestCase):
         response = self.client.post(url, data, format='json', **self.header)
         response = self.client.post(url, data, format='json', **self.header)
 
 
         self.assertHttpStatus(response, status.HTTP_201_CREATED)
         self.assertHttpStatus(response, status.HTTP_201_CREATED)
-        self.assertEqual(Site.objects.count(), 6)
+        self.assertEqual(Site.objects.count(), 4)
         site4 = Site.objects.get(pk=response.data['id'])
         site4 = Site.objects.get(pk=response.data['id'])
         self.assertEqual(site4.name, data['name'])
         self.assertEqual(site4.name, data['name'])
         self.assertEqual(site4.slug, data['slug'])
         self.assertEqual(site4.slug, data['slug'])
@@ -236,7 +219,7 @@ class SiteTest(APITestCase):
         response = self.client.post(url, data, format='json', **self.header)
         response = self.client.post(url, data, format='json', **self.header)
 
 
         self.assertHttpStatus(response, status.HTTP_201_CREATED)
         self.assertHttpStatus(response, status.HTTP_201_CREATED)
-        self.assertEqual(Site.objects.count(), 8)
+        self.assertEqual(Site.objects.count(), 6)
         self.assertEqual(response.data[0]['name'], data[0]['name'])
         self.assertEqual(response.data[0]['name'], data[0]['name'])
         self.assertEqual(response.data[1]['name'], data[1]['name'])
         self.assertEqual(response.data[1]['name'], data[1]['name'])
         self.assertEqual(response.data[2]['name'], data[2]['name'])
         self.assertEqual(response.data[2]['name'], data[2]['name'])
@@ -253,7 +236,7 @@ class SiteTest(APITestCase):
         response = self.client.put(url, data, format='json', **self.header)
         response = self.client.put(url, data, format='json', **self.header)
 
 
         self.assertHttpStatus(response, status.HTTP_200_OK)
         self.assertHttpStatus(response, status.HTTP_200_OK)
-        self.assertEqual(Site.objects.count(), 5)
+        self.assertEqual(Site.objects.count(), 3)
         site1 = Site.objects.get(pk=response.data['id'])
         site1 = Site.objects.get(pk=response.data['id'])
         self.assertEqual(site1.name, data['name'])
         self.assertEqual(site1.name, data['name'])
         self.assertEqual(site1.slug, data['slug'])
         self.assertEqual(site1.slug, data['slug'])
@@ -265,7 +248,7 @@ class SiteTest(APITestCase):
         response = self.client.delete(url, **self.header)
         response = self.client.delete(url, **self.header)
 
 
         self.assertHttpStatus(response, status.HTTP_204_NO_CONTENT)
         self.assertHttpStatus(response, status.HTTP_204_NO_CONTENT)
-        self.assertEqual(Site.objects.count(), 4)
+        self.assertEqual(Site.objects.count(), 2)
 
 
 
 
 class RackGroupTest(APITestCase):
 class RackGroupTest(APITestCase):
@@ -1769,8 +1752,7 @@ class DeviceTest(APITestCase):
 
 
         super().setUp()
         super().setUp()
 
 
-        region = Region.objects.create(name='Test Region', slug='test-region-1')
-        self.site1 = Site.objects.create(region=region, name='Test Site 1', slug='test-site-1')
+        self.site1 = Site.objects.create(name='Test Site 1', slug='test-site-1')
         self.site2 = Site.objects.create(name='Test Site 2', slug='test-site-2')
         self.site2 = Site.objects.create(name='Test Site 2', slug='test-site-2')
         manufacturer = Manufacturer.objects.create(name='Test Manufacturer 1', slug='test-manufacturer-1')
         manufacturer = Manufacturer.objects.create(name='Test Manufacturer 1', slug='test-manufacturer-1')
         self.devicetype1 = DeviceType.objects.create(
         self.devicetype1 = DeviceType.objects.create(
@@ -1818,20 +1800,6 @@ class DeviceTest(APITestCase):
                 'B': 2
                 'B': 2
             }
             }
         )
         )
-        self.device_non_region1 = Device.objects.create(
-            device_type=self.devicetype1,
-            device_role=self.devicerole1,
-            name='Test Device Null Region1',
-            site=self.site2,
-            cluster=self.cluster1
-        )
-        self.device_non_region2 = Device.objects.create(
-            device_type=self.devicetype1,
-            device_role=self.devicerole1,
-            name='Test Device Null Region2',
-            site=self.site2,
-            cluster=self.cluster1
-        )
 
 
     def test_get_device(self):
     def test_get_device(self):
 
 
@@ -1847,7 +1815,7 @@ class DeviceTest(APITestCase):
         url = reverse('dcim-api:device-list')
         url = reverse('dcim-api:device-list')
         response = self.client.get(url, **self.header)
         response = self.client.get(url, **self.header)
 
 
-        self.assertEqual(response.data['count'], 6)
+        self.assertEqual(response.data['count'], 4)
 
 
     def test_list_devices_brief(self):
     def test_list_devices_brief(self):
 
 
@@ -1859,20 +1827,6 @@ class DeviceTest(APITestCase):
             ['display_name', 'id', 'name', 'url']
             ['display_name', 'id', 'name', 'url']
         )
         )
 
 
-    def test_list_devices_null_region(self):
-
-        url = reverse('dcim-api:device-list')
-        response = self.client.get('{}?region=null'.format(url), **self.header)
-
-        self.assertEqual(response.data['count'], 2)
-
-    def test_list_devices_multiple_regions(self):
-
-        url = reverse('dcim-api:device-list')
-        response = self.client.get('{}?region=null&region=test-region-1'.format(url), **self.header)
-
-        self.assertEqual(response.data['count'], 6)
-
     def test_create_device(self):
     def test_create_device(self):
 
 
         data = {
         data = {
@@ -1887,7 +1841,7 @@ class DeviceTest(APITestCase):
         response = self.client.post(url, data, format='json', **self.header)
         response = self.client.post(url, data, format='json', **self.header)
 
 
         self.assertHttpStatus(response, status.HTTP_201_CREATED)
         self.assertHttpStatus(response, status.HTTP_201_CREATED)
-        self.assertEqual(Device.objects.count(), 7)
+        self.assertEqual(Device.objects.count(), 5)
         device4 = Device.objects.get(pk=response.data['id'])
         device4 = Device.objects.get(pk=response.data['id'])
         self.assertEqual(device4.device_type_id, data['device_type'])
         self.assertEqual(device4.device_type_id, data['device_type'])
         self.assertEqual(device4.device_role_id, data['device_role'])
         self.assertEqual(device4.device_role_id, data['device_role'])
@@ -1922,7 +1876,7 @@ class DeviceTest(APITestCase):
         response = self.client.post(url, data, format='json', **self.header)
         response = self.client.post(url, data, format='json', **self.header)
 
 
         self.assertHttpStatus(response, status.HTTP_201_CREATED)
         self.assertHttpStatus(response, status.HTTP_201_CREATED)
-        self.assertEqual(Device.objects.count(), 9)
+        self.assertEqual(Device.objects.count(), 7)
         self.assertEqual(response.data[0]['name'], data[0]['name'])
         self.assertEqual(response.data[0]['name'], data[0]['name'])
         self.assertEqual(response.data[1]['name'], data[1]['name'])
         self.assertEqual(response.data[1]['name'], data[1]['name'])
         self.assertEqual(response.data[2]['name'], data[2]['name'])
         self.assertEqual(response.data[2]['name'], data[2]['name'])
@@ -1946,7 +1900,7 @@ class DeviceTest(APITestCase):
         response = self.client.put(url, data, format='json', **self.header)
         response = self.client.put(url, data, format='json', **self.header)
 
 
         self.assertHttpStatus(response, status.HTTP_200_OK)
         self.assertHttpStatus(response, status.HTTP_200_OK)
-        self.assertEqual(Device.objects.count(), 6)
+        self.assertEqual(Device.objects.count(), 4)
         device1 = Device.objects.get(pk=response.data['id'])
         device1 = Device.objects.get(pk=response.data['id'])
         self.assertEqual(device1.device_type_id, data['device_type'])
         self.assertEqual(device1.device_type_id, data['device_type'])
         self.assertEqual(device1.device_role_id, data['device_role'])
         self.assertEqual(device1.device_role_id, data['device_role'])
@@ -1961,7 +1915,7 @@ class DeviceTest(APITestCase):
         response = self.client.delete(url, **self.header)
         response = self.client.delete(url, **self.header)
 
 
         self.assertHttpStatus(response, status.HTTP_204_NO_CONTENT)
         self.assertHttpStatus(response, status.HTTP_204_NO_CONTENT)
-        self.assertEqual(Device.objects.count(), 5)
+        self.assertEqual(Device.objects.count(), 3)
 
 
     def test_config_context_included_by_default_in_list_view(self):
     def test_config_context_included_by_default_in_list_view(self):
 
 

+ 62 - 0
netbox/utilities/tests/test_filters.py

@@ -0,0 +1,62 @@
+from django.conf import settings
+from django.test import TestCase
+import django_filters
+
+from dcim.models import Region, Site
+from utilities.filters import TreeNodeMultipleChoiceFilter
+
+
+class TreeNodeMultipleChoiceFilterTest(TestCase):
+
+    class SiteFilterSet(django_filters.FilterSet):
+        region = TreeNodeMultipleChoiceFilter(
+            queryset=Region.objects.all(),
+            field_name='region__in',
+            to_field_name='slug',
+        )
+
+    def setUp(self):
+
+        super().setUp()
+
+        self.region1 = Region.objects.create(name='Test Region 1', slug='test-region-1')
+        self.region2 = Region.objects.create(name='Test Region 2', slug='test-region-2')
+        self.site1 = Site.objects.create(region=self.region1, name='Test Site 1', slug='test-site1')
+        self.site2 = Site.objects.create(region=self.region2, name='Test Site 2', slug='test-site2')
+        self.site3 = Site.objects.create(region=None, name='Test Site 3', slug='test-site3')
+
+        self.queryset = Site.objects.all()
+
+    def test_filter_single(self):
+
+        kwargs = {'region': ['test-region-1']}
+        qs = self.SiteFilterSet(kwargs, self.queryset).qs
+
+        self.assertEqual(qs.count(), 1)
+        self.assertEqual(qs[0], self.site1)
+
+    def test_filter_multiple(self):
+
+        kwargs = {'region': ['test-region-1', 'test-region-2']}
+        qs = self.SiteFilterSet(kwargs, self.queryset).qs
+
+        self.assertEqual(qs.count(), 2)
+        self.assertEqual(qs[0], self.site1)
+        self.assertEqual(qs[1], self.site2)
+
+    def test_filter_null(self):
+
+        kwargs = {'region': [settings.FILTERS_NULL_CHOICE_VALUE]}
+        qs = self.SiteFilterSet(kwargs, self.queryset).qs
+
+        self.assertEqual(qs.count(), 1)
+        self.assertEqual(qs[0], self.site3)
+
+    def test_filter_combined(self):
+
+        kwargs = {'region': ['test-region-1', settings.FILTERS_NULL_CHOICE_VALUE]}
+        qs = self.SiteFilterSet(kwargs, self.queryset).qs
+
+        self.assertEqual(qs.count(), 2)
+        self.assertEqual(qs[0], self.site1)
+        self.assertEqual(qs[1], self.site3)

+ 14 - 30
netbox/virtualization/tests/test_api.py

@@ -3,7 +3,7 @@ from netaddr import IPNetwork
 from rest_framework import status
 from rest_framework import status
 
 
 from dcim.constants import IFACE_TYPE_VIRTUAL, IFACE_MODE_TAGGED
 from dcim.constants import IFACE_TYPE_VIRTUAL, IFACE_MODE_TAGGED
-from dcim.models import Interface, Region, Site
+from dcim.models import Interface
 from ipam.models import IPAddress, VLAN
 from ipam.models import IPAddress, VLAN
 from utilities.testing import APITestCase
 from utilities.testing import APITestCase
 from virtualization.models import Cluster, ClusterGroup, ClusterType, VirtualMachine
 from virtualization.models import Cluster, ClusterGroup, ClusterType, VirtualMachine
@@ -330,14 +330,9 @@ class VirtualMachineTest(APITestCase):
 
 
         super().setUp()
         super().setUp()
 
 
-        region = Region.objects.create(name='Test Region 1', slug='test-region-1')
-        site1 = Site.objects.create(region=region, name='Test Site 1', slug='test-site-1')
-        site2 = Site.objects.create(name='Test Site 2', slug='test-site-2')
-
         cluster_type = ClusterType.objects.create(name='Test Cluster Type 1', slug='test-cluster-type-1')
         cluster_type = ClusterType.objects.create(name='Test Cluster Type 1', slug='test-cluster-type-1')
         cluster_group = ClusterGroup.objects.create(name='Test Cluster Group 1', slug='test-cluster-group-1')
         cluster_group = ClusterGroup.objects.create(name='Test Cluster Group 1', slug='test-cluster-group-1')
-        self.cluster1 = Cluster.objects.create(name='Test Cluster 1', type=cluster_type, group=cluster_group, site=site1)
-        self.cluster2 = Cluster.objects.create(name='Test Cluster 2', type=cluster_type, group=cluster_group, site=site2)
+        self.cluster1 = Cluster.objects.create(name='Test Cluster 1', type=cluster_type, group=cluster_group)
 
 
         self.virtualmachine1 = VirtualMachine.objects.create(name='Test Virtual Machine 1', cluster=self.cluster1)
         self.virtualmachine1 = VirtualMachine.objects.create(name='Test Virtual Machine 1', cluster=self.cluster1)
         self.virtualmachine2 = VirtualMachine.objects.create(name='Test Virtual Machine 2', cluster=self.cluster1)
         self.virtualmachine2 = VirtualMachine.objects.create(name='Test Virtual Machine 2', cluster=self.cluster1)
@@ -350,8 +345,6 @@ class VirtualMachineTest(APITestCase):
                 'B': 2
                 'B': 2
             }
             }
         )
         )
-        self.virtualmachine_non_region1 = VirtualMachine.objects.create(name='Test Virtual Machine Null Region1', cluster=self.cluster2)
-        self.virtualmachine_non_region2 = VirtualMachine.objects.create(name='Test Virtual Machine Null Region2', cluster=self.cluster2)
 
 
     def test_get_virtualmachine(self):
     def test_get_virtualmachine(self):
 
 
@@ -365,7 +358,7 @@ class VirtualMachineTest(APITestCase):
         url = reverse('virtualization-api:virtualmachine-list')
         url = reverse('virtualization-api:virtualmachine-list')
         response = self.client.get(url, **self.header)
         response = self.client.get(url, **self.header)
 
 
-        self.assertEqual(response.data['count'], 6)
+        self.assertEqual(response.data['count'], 4)
 
 
     def test_list_virtualmachines_brief(self):
     def test_list_virtualmachines_brief(self):
 
 
@@ -377,20 +370,6 @@ class VirtualMachineTest(APITestCase):
             ['id', 'name', 'url']
             ['id', 'name', 'url']
         )
         )
 
 
-    def test_list_virtualmachines_null_region(self):
-
-        url = reverse('virtualization-api:virtualmachine-list')
-        response = self.client.get('{}?region=null'.format(url), **self.header)
-
-        self.assertEqual(response.data['count'], 2)
-
-    def test_list_virtualmachines_multiple_regions(self):
-
-        url = reverse('virtualization-api:virtualmachine-list')
-        response = self.client.get('{}?region=null&region=test-region-1'.format(url), **self.header)
-
-        self.assertEqual(response.data['count'], 6)
-
     def test_create_virtualmachine(self):
     def test_create_virtualmachine(self):
 
 
         data = {
         data = {
@@ -402,7 +381,7 @@ class VirtualMachineTest(APITestCase):
         response = self.client.post(url, data, format='json', **self.header)
         response = self.client.post(url, data, format='json', **self.header)
 
 
         self.assertHttpStatus(response, status.HTTP_201_CREATED)
         self.assertHttpStatus(response, status.HTTP_201_CREATED)
-        self.assertEqual(VirtualMachine.objects.count(), 7)
+        self.assertEqual(VirtualMachine.objects.count(), 5)
         virtualmachine4 = VirtualMachine.objects.get(pk=response.data['id'])
         virtualmachine4 = VirtualMachine.objects.get(pk=response.data['id'])
         self.assertEqual(virtualmachine4.name, data['name'])
         self.assertEqual(virtualmachine4.name, data['name'])
         self.assertEqual(virtualmachine4.cluster.pk, data['cluster'])
         self.assertEqual(virtualmachine4.cluster.pk, data['cluster'])
@@ -417,7 +396,7 @@ class VirtualMachineTest(APITestCase):
         response = self.client.post(url, data, format='json', **self.header)
         response = self.client.post(url, data, format='json', **self.header)
 
 
         self.assertHttpStatus(response, status.HTTP_400_BAD_REQUEST)
         self.assertHttpStatus(response, status.HTTP_400_BAD_REQUEST)
-        self.assertEqual(VirtualMachine.objects.count(), 6)
+        self.assertEqual(VirtualMachine.objects.count(), 4)
 
 
     def test_create_virtualmachine_bulk(self):
     def test_create_virtualmachine_bulk(self):
 
 
@@ -440,7 +419,7 @@ class VirtualMachineTest(APITestCase):
         response = self.client.post(url, data, format='json', **self.header)
         response = self.client.post(url, data, format='json', **self.header)
 
 
         self.assertHttpStatus(response, status.HTTP_201_CREATED)
         self.assertHttpStatus(response, status.HTTP_201_CREATED)
-        self.assertEqual(VirtualMachine.objects.count(), 9)
+        self.assertEqual(VirtualMachine.objects.count(), 7)
         self.assertEqual(response.data[0]['name'], data[0]['name'])
         self.assertEqual(response.data[0]['name'], data[0]['name'])
         self.assertEqual(response.data[1]['name'], data[1]['name'])
         self.assertEqual(response.data[1]['name'], data[1]['name'])
         self.assertEqual(response.data[2]['name'], data[2]['name'])
         self.assertEqual(response.data[2]['name'], data[2]['name'])
@@ -451,9 +430,14 @@ class VirtualMachineTest(APITestCase):
         ip4_address = IPAddress.objects.create(address=IPNetwork('192.0.2.1/24'), interface=interface)
         ip4_address = IPAddress.objects.create(address=IPNetwork('192.0.2.1/24'), interface=interface)
         ip6_address = IPAddress.objects.create(address=IPNetwork('2001:db8::1/64'), interface=interface)
         ip6_address = IPAddress.objects.create(address=IPNetwork('2001:db8::1/64'), interface=interface)
 
 
+        cluster2 = Cluster.objects.create(
+            name='Test Cluster 2',
+            type=ClusterType.objects.first(),
+            group=ClusterGroup.objects.first()
+        )
         data = {
         data = {
             'name': 'Test Virtual Machine X',
             'name': 'Test Virtual Machine X',
-            'cluster': self.cluster2.pk,
+            'cluster': cluster2.pk,
             'primary_ip4': ip4_address.pk,
             'primary_ip4': ip4_address.pk,
             'primary_ip6': ip6_address.pk,
             'primary_ip6': ip6_address.pk,
         }
         }
@@ -462,7 +446,7 @@ class VirtualMachineTest(APITestCase):
         response = self.client.put(url, data, format='json', **self.header)
         response = self.client.put(url, data, format='json', **self.header)
 
 
         self.assertHttpStatus(response, status.HTTP_200_OK)
         self.assertHttpStatus(response, status.HTTP_200_OK)
-        self.assertEqual(VirtualMachine.objects.count(), 6)
+        self.assertEqual(VirtualMachine.objects.count(), 4)
         virtualmachine1 = VirtualMachine.objects.get(pk=response.data['id'])
         virtualmachine1 = VirtualMachine.objects.get(pk=response.data['id'])
         self.assertEqual(virtualmachine1.name, data['name'])
         self.assertEqual(virtualmachine1.name, data['name'])
         self.assertEqual(virtualmachine1.cluster.pk, data['cluster'])
         self.assertEqual(virtualmachine1.cluster.pk, data['cluster'])
@@ -475,7 +459,7 @@ class VirtualMachineTest(APITestCase):
         response = self.client.delete(url, **self.header)
         response = self.client.delete(url, **self.header)
 
 
         self.assertHttpStatus(response, status.HTTP_204_NO_CONTENT)
         self.assertHttpStatus(response, status.HTTP_204_NO_CONTENT)
-        self.assertEqual(VirtualMachine.objects.count(), 5)
+        self.assertEqual(VirtualMachine.objects.count(), 3)
 
 
     def test_config_context_included_by_default_in_list_view(self):
     def test_config_context_included_by_default_in_list_view(self):