Просмотр исходного кода

Move TreeNodeMultipleChoiceFilter tests to utilities (follow-up to #3616)

Jeremy Stretch 6 лет назад
Родитель
Сommit
875e09013c

+ 11 - 57
netbox/dcim/tests/test_api.py

@@ -127,9 +127,6 @@ class SiteTest(APITestCase):
         self.region2 = Region.objects.create(name='Test Region 2', slug='test-region-2')
         self.site1 = Site.objects.create(region=self.region1, name='Test Site 1', slug='test-site-1')
         self.site2 = Site.objects.create(region=self.region1, name='Test Site 2', slug='test-site-2')
-        self.site3 = Site.objects.create(region=self.region2, name='Test Site 3', slug='test-site-3')
-        self.site_non_region1 = Site.objects.create(name='Test Site Null Region1', slug='test-site-no-region1')
-        self.site_non_region2 = Site.objects.create(name='Test Site Null Region2', slug='test-site-no-region2')
 
     def test_get_site(self):
 
@@ -164,7 +161,7 @@ class SiteTest(APITestCase):
         url = reverse('dcim-api:site-list')
         response = self.client.get(url, **self.header)
 
-        self.assertEqual(response.data['count'], 5)
+        self.assertEqual(response.data['count'], 3)
 
     def test_list_sites_brief(self):
 
@@ -176,20 +173,6 @@ class SiteTest(APITestCase):
             ['id', 'name', 'slug', 'url']
         )
 
-    def test_list_sites_null_region(self):
-
-        url = reverse('dcim-api:site-list')
-        response = self.client.get('{}?region=null'.format(url), **self.header)
-
-        self.assertEqual(response.data['count'], 2)
-
-    def test_list_sites_multiple_regions(self):
-
-        url = reverse('dcim-api:site-list')
-        response = self.client.get('{}?region=null&region=test-region-1'.format(url), **self.header)
-
-        self.assertEqual(response.data['count'], 4)
-
     def test_create_site(self):
 
         data = {
@@ -203,7 +186,7 @@ class SiteTest(APITestCase):
         response = self.client.post(url, data, format='json', **self.header)
 
         self.assertHttpStatus(response, status.HTTP_201_CREATED)
-        self.assertEqual(Site.objects.count(), 6)
+        self.assertEqual(Site.objects.count(), 4)
         site4 = Site.objects.get(pk=response.data['id'])
         self.assertEqual(site4.name, data['name'])
         self.assertEqual(site4.slug, data['slug'])
@@ -236,7 +219,7 @@ class SiteTest(APITestCase):
         response = self.client.post(url, data, format='json', **self.header)
 
         self.assertHttpStatus(response, status.HTTP_201_CREATED)
-        self.assertEqual(Site.objects.count(), 8)
+        self.assertEqual(Site.objects.count(), 6)
         self.assertEqual(response.data[0]['name'], data[0]['name'])
         self.assertEqual(response.data[1]['name'], data[1]['name'])
         self.assertEqual(response.data[2]['name'], data[2]['name'])
@@ -253,7 +236,7 @@ class SiteTest(APITestCase):
         response = self.client.put(url, data, format='json', **self.header)
 
         self.assertHttpStatus(response, status.HTTP_200_OK)
-        self.assertEqual(Site.objects.count(), 5)
+        self.assertEqual(Site.objects.count(), 3)
         site1 = Site.objects.get(pk=response.data['id'])
         self.assertEqual(site1.name, data['name'])
         self.assertEqual(site1.slug, data['slug'])
@@ -265,7 +248,7 @@ class SiteTest(APITestCase):
         response = self.client.delete(url, **self.header)
 
         self.assertHttpStatus(response, status.HTTP_204_NO_CONTENT)
-        self.assertEqual(Site.objects.count(), 4)
+        self.assertEqual(Site.objects.count(), 2)
 
 
 class RackGroupTest(APITestCase):
@@ -1769,8 +1752,7 @@ class DeviceTest(APITestCase):
 
         super().setUp()
 
-        region = Region.objects.create(name='Test Region', slug='test-region-1')
-        self.site1 = Site.objects.create(region=region, name='Test Site 1', slug='test-site-1')
+        self.site1 = Site.objects.create(name='Test Site 1', slug='test-site-1')
         self.site2 = Site.objects.create(name='Test Site 2', slug='test-site-2')
         manufacturer = Manufacturer.objects.create(name='Test Manufacturer 1', slug='test-manufacturer-1')
         self.devicetype1 = DeviceType.objects.create(
@@ -1818,20 +1800,6 @@ class DeviceTest(APITestCase):
                 'B': 2
             }
         )
-        self.device_non_region1 = Device.objects.create(
-            device_type=self.devicetype1,
-            device_role=self.devicerole1,
-            name='Test Device Null Region1',
-            site=self.site2,
-            cluster=self.cluster1
-        )
-        self.device_non_region2 = Device.objects.create(
-            device_type=self.devicetype1,
-            device_role=self.devicerole1,
-            name='Test Device Null Region2',
-            site=self.site2,
-            cluster=self.cluster1
-        )
 
     def test_get_device(self):
 
@@ -1847,7 +1815,7 @@ class DeviceTest(APITestCase):
         url = reverse('dcim-api:device-list')
         response = self.client.get(url, **self.header)
 
-        self.assertEqual(response.data['count'], 6)
+        self.assertEqual(response.data['count'], 4)
 
     def test_list_devices_brief(self):
 
@@ -1859,20 +1827,6 @@ class DeviceTest(APITestCase):
             ['display_name', 'id', 'name', 'url']
         )
 
-    def test_list_devices_null_region(self):
-
-        url = reverse('dcim-api:device-list')
-        response = self.client.get('{}?region=null'.format(url), **self.header)
-
-        self.assertEqual(response.data['count'], 2)
-
-    def test_list_devices_multiple_regions(self):
-
-        url = reverse('dcim-api:device-list')
-        response = self.client.get('{}?region=null&region=test-region-1'.format(url), **self.header)
-
-        self.assertEqual(response.data['count'], 6)
-
     def test_create_device(self):
 
         data = {
@@ -1887,7 +1841,7 @@ class DeviceTest(APITestCase):
         response = self.client.post(url, data, format='json', **self.header)
 
         self.assertHttpStatus(response, status.HTTP_201_CREATED)
-        self.assertEqual(Device.objects.count(), 7)
+        self.assertEqual(Device.objects.count(), 5)
         device4 = Device.objects.get(pk=response.data['id'])
         self.assertEqual(device4.device_type_id, data['device_type'])
         self.assertEqual(device4.device_role_id, data['device_role'])
@@ -1922,7 +1876,7 @@ class DeviceTest(APITestCase):
         response = self.client.post(url, data, format='json', **self.header)
 
         self.assertHttpStatus(response, status.HTTP_201_CREATED)
-        self.assertEqual(Device.objects.count(), 9)
+        self.assertEqual(Device.objects.count(), 7)
         self.assertEqual(response.data[0]['name'], data[0]['name'])
         self.assertEqual(response.data[1]['name'], data[1]['name'])
         self.assertEqual(response.data[2]['name'], data[2]['name'])
@@ -1946,7 +1900,7 @@ class DeviceTest(APITestCase):
         response = self.client.put(url, data, format='json', **self.header)
 
         self.assertHttpStatus(response, status.HTTP_200_OK)
-        self.assertEqual(Device.objects.count(), 6)
+        self.assertEqual(Device.objects.count(), 4)
         device1 = Device.objects.get(pk=response.data['id'])
         self.assertEqual(device1.device_type_id, data['device_type'])
         self.assertEqual(device1.device_role_id, data['device_role'])
@@ -1961,7 +1915,7 @@ class DeviceTest(APITestCase):
         response = self.client.delete(url, **self.header)
 
         self.assertHttpStatus(response, status.HTTP_204_NO_CONTENT)
-        self.assertEqual(Device.objects.count(), 5)
+        self.assertEqual(Device.objects.count(), 3)
 
     def test_config_context_included_by_default_in_list_view(self):
 

+ 62 - 0
netbox/utilities/tests/test_filters.py

@@ -0,0 +1,62 @@
+from django.conf import settings
+from django.test import TestCase
+import django_filters
+
+from dcim.models import Region, Site
+from utilities.filters import TreeNodeMultipleChoiceFilter
+
+
+class TreeNodeMultipleChoiceFilterTest(TestCase):
+
+    class SiteFilterSet(django_filters.FilterSet):
+        region = TreeNodeMultipleChoiceFilter(
+            queryset=Region.objects.all(),
+            field_name='region__in',
+            to_field_name='slug',
+        )
+
+    def setUp(self):
+
+        super().setUp()
+
+        self.region1 = Region.objects.create(name='Test Region 1', slug='test-region-1')
+        self.region2 = Region.objects.create(name='Test Region 2', slug='test-region-2')
+        self.site1 = Site.objects.create(region=self.region1, name='Test Site 1', slug='test-site1')
+        self.site2 = Site.objects.create(region=self.region2, name='Test Site 2', slug='test-site2')
+        self.site3 = Site.objects.create(region=None, name='Test Site 3', slug='test-site3')
+
+        self.queryset = Site.objects.all()
+
+    def test_filter_single(self):
+
+        kwargs = {'region': ['test-region-1']}
+        qs = self.SiteFilterSet(kwargs, self.queryset).qs
+
+        self.assertEqual(qs.count(), 1)
+        self.assertEqual(qs[0], self.site1)
+
+    def test_filter_multiple(self):
+
+        kwargs = {'region': ['test-region-1', 'test-region-2']}
+        qs = self.SiteFilterSet(kwargs, self.queryset).qs
+
+        self.assertEqual(qs.count(), 2)
+        self.assertEqual(qs[0], self.site1)
+        self.assertEqual(qs[1], self.site2)
+
+    def test_filter_null(self):
+
+        kwargs = {'region': [settings.FILTERS_NULL_CHOICE_VALUE]}
+        qs = self.SiteFilterSet(kwargs, self.queryset).qs
+
+        self.assertEqual(qs.count(), 1)
+        self.assertEqual(qs[0], self.site3)
+
+    def test_filter_combined(self):
+
+        kwargs = {'region': ['test-region-1', settings.FILTERS_NULL_CHOICE_VALUE]}
+        qs = self.SiteFilterSet(kwargs, self.queryset).qs
+
+        self.assertEqual(qs.count(), 2)
+        self.assertEqual(qs[0], self.site1)
+        self.assertEqual(qs[1], self.site3)

+ 14 - 30
netbox/virtualization/tests/test_api.py

@@ -3,7 +3,7 @@ from netaddr import IPNetwork
 from rest_framework import status
 
 from dcim.constants import IFACE_TYPE_VIRTUAL, IFACE_MODE_TAGGED
-from dcim.models import Interface, Region, Site
+from dcim.models import Interface
 from ipam.models import IPAddress, VLAN
 from utilities.testing import APITestCase
 from virtualization.models import Cluster, ClusterGroup, ClusterType, VirtualMachine
@@ -330,14 +330,9 @@ class VirtualMachineTest(APITestCase):
 
         super().setUp()
 
-        region = Region.objects.create(name='Test Region 1', slug='test-region-1')
-        site1 = Site.objects.create(region=region, name='Test Site 1', slug='test-site-1')
-        site2 = Site.objects.create(name='Test Site 2', slug='test-site-2')
-
         cluster_type = ClusterType.objects.create(name='Test Cluster Type 1', slug='test-cluster-type-1')
         cluster_group = ClusterGroup.objects.create(name='Test Cluster Group 1', slug='test-cluster-group-1')
-        self.cluster1 = Cluster.objects.create(name='Test Cluster 1', type=cluster_type, group=cluster_group, site=site1)
-        self.cluster2 = Cluster.objects.create(name='Test Cluster 2', type=cluster_type, group=cluster_group, site=site2)
+        self.cluster1 = Cluster.objects.create(name='Test Cluster 1', type=cluster_type, group=cluster_group)
 
         self.virtualmachine1 = VirtualMachine.objects.create(name='Test Virtual Machine 1', cluster=self.cluster1)
         self.virtualmachine2 = VirtualMachine.objects.create(name='Test Virtual Machine 2', cluster=self.cluster1)
@@ -350,8 +345,6 @@ class VirtualMachineTest(APITestCase):
                 'B': 2
             }
         )
-        self.virtualmachine_non_region1 = VirtualMachine.objects.create(name='Test Virtual Machine Null Region1', cluster=self.cluster2)
-        self.virtualmachine_non_region2 = VirtualMachine.objects.create(name='Test Virtual Machine Null Region2', cluster=self.cluster2)
 
     def test_get_virtualmachine(self):
 
@@ -365,7 +358,7 @@ class VirtualMachineTest(APITestCase):
         url = reverse('virtualization-api:virtualmachine-list')
         response = self.client.get(url, **self.header)
 
-        self.assertEqual(response.data['count'], 6)
+        self.assertEqual(response.data['count'], 4)
 
     def test_list_virtualmachines_brief(self):
 
@@ -377,20 +370,6 @@ class VirtualMachineTest(APITestCase):
             ['id', 'name', 'url']
         )
 
-    def test_list_virtualmachines_null_region(self):
-
-        url = reverse('virtualization-api:virtualmachine-list')
-        response = self.client.get('{}?region=null'.format(url), **self.header)
-
-        self.assertEqual(response.data['count'], 2)
-
-    def test_list_virtualmachines_multiple_regions(self):
-
-        url = reverse('virtualization-api:virtualmachine-list')
-        response = self.client.get('{}?region=null&region=test-region-1'.format(url), **self.header)
-
-        self.assertEqual(response.data['count'], 6)
-
     def test_create_virtualmachine(self):
 
         data = {
@@ -402,7 +381,7 @@ class VirtualMachineTest(APITestCase):
         response = self.client.post(url, data, format='json', **self.header)
 
         self.assertHttpStatus(response, status.HTTP_201_CREATED)
-        self.assertEqual(VirtualMachine.objects.count(), 7)
+        self.assertEqual(VirtualMachine.objects.count(), 5)
         virtualmachine4 = VirtualMachine.objects.get(pk=response.data['id'])
         self.assertEqual(virtualmachine4.name, data['name'])
         self.assertEqual(virtualmachine4.cluster.pk, data['cluster'])
@@ -417,7 +396,7 @@ class VirtualMachineTest(APITestCase):
         response = self.client.post(url, data, format='json', **self.header)
 
         self.assertHttpStatus(response, status.HTTP_400_BAD_REQUEST)
-        self.assertEqual(VirtualMachine.objects.count(), 6)
+        self.assertEqual(VirtualMachine.objects.count(), 4)
 
     def test_create_virtualmachine_bulk(self):
 
@@ -440,7 +419,7 @@ class VirtualMachineTest(APITestCase):
         response = self.client.post(url, data, format='json', **self.header)
 
         self.assertHttpStatus(response, status.HTTP_201_CREATED)
-        self.assertEqual(VirtualMachine.objects.count(), 9)
+        self.assertEqual(VirtualMachine.objects.count(), 7)
         self.assertEqual(response.data[0]['name'], data[0]['name'])
         self.assertEqual(response.data[1]['name'], data[1]['name'])
         self.assertEqual(response.data[2]['name'], data[2]['name'])
@@ -451,9 +430,14 @@ class VirtualMachineTest(APITestCase):
         ip4_address = IPAddress.objects.create(address=IPNetwork('192.0.2.1/24'), interface=interface)
         ip6_address = IPAddress.objects.create(address=IPNetwork('2001:db8::1/64'), interface=interface)
 
+        cluster2 = Cluster.objects.create(
+            name='Test Cluster 2',
+            type=ClusterType.objects.first(),
+            group=ClusterGroup.objects.first()
+        )
         data = {
             'name': 'Test Virtual Machine X',
-            'cluster': self.cluster2.pk,
+            'cluster': cluster2.pk,
             'primary_ip4': ip4_address.pk,
             'primary_ip6': ip6_address.pk,
         }
@@ -462,7 +446,7 @@ class VirtualMachineTest(APITestCase):
         response = self.client.put(url, data, format='json', **self.header)
 
         self.assertHttpStatus(response, status.HTTP_200_OK)
-        self.assertEqual(VirtualMachine.objects.count(), 6)
+        self.assertEqual(VirtualMachine.objects.count(), 4)
         virtualmachine1 = VirtualMachine.objects.get(pk=response.data['id'])
         self.assertEqual(virtualmachine1.name, data['name'])
         self.assertEqual(virtualmachine1.cluster.pk, data['cluster'])
@@ -475,7 +459,7 @@ class VirtualMachineTest(APITestCase):
         response = self.client.delete(url, **self.header)
 
         self.assertHttpStatus(response, status.HTTP_204_NO_CONTENT)
-        self.assertEqual(VirtualMachine.objects.count(), 5)
+        self.assertEqual(VirtualMachine.objects.count(), 3)
 
     def test_config_context_included_by_default_in_list_view(self):