Просмотр исходного кода

filtering multiple regions with null

kobayashi 6 лет назад
Родитель
Сommit
d2aa9b8e79
3 измененных файлов с 67 добавлено и 40 удалено
  1. 44 25
      netbox/dcim/tests/test_api.py
  2. 7 6
      netbox/utilities/filters.py
  3. 16 9
      netbox/virtualization/tests/test_api.py

+ 44 - 25
netbox/dcim/tests/test_api.py

@@ -127,7 +127,9 @@ class SiteTest(APITestCase):
         self.region2 = Region.objects.create(name='Test Region 2', slug='test-region-2')
         self.site1 = Site.objects.create(region=self.region1, name='Test Site 1', slug='test-site-1')
         self.site2 = Site.objects.create(region=self.region1, name='Test Site 2', slug='test-site-2')
-        self.site3 = Site.objects.create(region=self.region1, name='Test Site 3', slug='test-site-3')
+        self.site3 = Site.objects.create(region=self.region2, name='Test Site 3', slug='test-site-3')
+        self.site_non_region1 = Site.objects.create(name='Test Site Null Region1', slug='test-site-no-region1')
+        self.site_non_region2 = Site.objects.create(name='Test Site Null Region2', slug='test-site-no-region2')
 
     def test_get_site(self):
 
@@ -162,7 +164,7 @@ class SiteTest(APITestCase):
         url = reverse('dcim-api:site-list')
         response = self.client.get(url, **self.header)
 
-        self.assertEqual(response.data['count'], 3)
+        self.assertEqual(response.data['count'], 5)
 
     def test_list_sites_brief(self):
 
@@ -176,14 +178,18 @@ class SiteTest(APITestCase):
 
     def test_list_sites_null_region(self):
 
-        Site.objects.create(name='Test Site Null Region1', slug='test-site-no-region1')
-        Site.objects.create(name='Test Site Null Region2', slug='test-site-no-region2')
-
         url = reverse('dcim-api:site-list')
         response = self.client.get('{}?region=null'.format(url), **self.header)
 
         self.assertEqual(response.data['count'], 2)
 
+    def test_list_sites_multiple_regions(self):
+
+        url = reverse('dcim-api:site-list')
+        response = self.client.get('{}?region=null&region=test-region-1'.format(url), **self.header)
+
+        self.assertEqual(response.data['count'], 4)
+
     def test_create_site(self):
 
         data = {
@@ -197,7 +203,7 @@ class SiteTest(APITestCase):
         response = self.client.post(url, data, format='json', **self.header)
 
         self.assertHttpStatus(response, status.HTTP_201_CREATED)
-        self.assertEqual(Site.objects.count(), 4)
+        self.assertEqual(Site.objects.count(), 6)
         site4 = Site.objects.get(pk=response.data['id'])
         self.assertEqual(site4.name, data['name'])
         self.assertEqual(site4.slug, data['slug'])
@@ -230,7 +236,7 @@ class SiteTest(APITestCase):
         response = self.client.post(url, data, format='json', **self.header)
 
         self.assertHttpStatus(response, status.HTTP_201_CREATED)
-        self.assertEqual(Site.objects.count(), 6)
+        self.assertEqual(Site.objects.count(), 8)
         self.assertEqual(response.data[0]['name'], data[0]['name'])
         self.assertEqual(response.data[1]['name'], data[1]['name'])
         self.assertEqual(response.data[2]['name'], data[2]['name'])
@@ -247,7 +253,7 @@ class SiteTest(APITestCase):
         response = self.client.put(url, data, format='json', **self.header)
 
         self.assertHttpStatus(response, status.HTTP_200_OK)
-        self.assertEqual(Site.objects.count(), 3)
+        self.assertEqual(Site.objects.count(), 5)
         site1 = Site.objects.get(pk=response.data['id'])
         self.assertEqual(site1.name, data['name'])
         self.assertEqual(site1.slug, data['slug'])
@@ -259,7 +265,7 @@ class SiteTest(APITestCase):
         response = self.client.delete(url, **self.header)
 
         self.assertHttpStatus(response, status.HTTP_204_NO_CONTENT)
-        self.assertEqual(Site.objects.count(), 2)
+        self.assertEqual(Site.objects.count(), 4)
 
 
 class RackGroupTest(APITestCase):
@@ -1763,7 +1769,7 @@ class DeviceTest(APITestCase):
 
         super().setUp()
 
-        region = Region.objects.create(name='Test Region', slug='test-region')
+        region = Region.objects.create(name='Test Region', slug='test-region-1')
         self.site1 = Site.objects.create(region=region, name='Test Site 1', slug='test-site-1')
         self.site2 = Site.objects.create(name='Test Site 2', slug='test-site-2')
         manufacturer = Manufacturer.objects.create(name='Test Manufacturer 1', slug='test-manufacturer-1')
@@ -1812,6 +1818,20 @@ class DeviceTest(APITestCase):
                 'B': 2
             }
         )
+        self.device_non_region1 = Device.objects.create(
+            device_type=self.devicetype1,
+            device_role=self.devicerole1,
+            name='Test Device Null Region1',
+            site=self.site2,
+            cluster=self.cluster1
+        )
+        self.device_non_region2 = Device.objects.create(
+            device_type=self.devicetype1,
+            device_role=self.devicerole1,
+            name='Test Device Null Region2',
+            site=self.site2,
+            cluster=self.cluster1
+        )
 
     def test_get_device(self):
 
@@ -1827,7 +1847,7 @@ class DeviceTest(APITestCase):
         url = reverse('dcim-api:device-list')
         response = self.client.get(url, **self.header)
 
-        self.assertEqual(response.data['count'], 4)
+        self.assertEqual(response.data['count'], 6)
 
     def test_list_devices_brief(self):
 
@@ -1839,20 +1859,19 @@ class DeviceTest(APITestCase):
             ['display_name', 'id', 'name', 'url']
         )
 
-    def test_list_device_null_region(self):
-
-        Device.objects.create(
-            device_type=self.devicetype1,
-            device_role=self.devicerole1,
-            name='Test Device Null Region',
-            site=self.site2,
-            cluster=self.cluster1
-        )
+    def test_list_devices_null_region(self):
 
         url = reverse('dcim-api:device-list')
         response = self.client.get('{}?region=null'.format(url), **self.header)
 
-        self.assertEqual(response.data['count'], 1)
+        self.assertEqual(response.data['count'], 2)
+
+    def test_list_devices_multiple_regions(self):
+
+        url = reverse('dcim-api:device-list')
+        response = self.client.get('{}?region=null&region=test-region-1'.format(url), **self.header)
+
+        self.assertEqual(response.data['count'], 6)
 
     def test_create_device(self):
 
@@ -1868,7 +1887,7 @@ class DeviceTest(APITestCase):
         response = self.client.post(url, data, format='json', **self.header)
 
         self.assertHttpStatus(response, status.HTTP_201_CREATED)
-        self.assertEqual(Device.objects.count(), 5)
+        self.assertEqual(Device.objects.count(), 7)
         device4 = Device.objects.get(pk=response.data['id'])
         self.assertEqual(device4.device_type_id, data['device_type'])
         self.assertEqual(device4.device_role_id, data['device_role'])
@@ -1903,7 +1922,7 @@ class DeviceTest(APITestCase):
         response = self.client.post(url, data, format='json', **self.header)
 
         self.assertHttpStatus(response, status.HTTP_201_CREATED)
-        self.assertEqual(Device.objects.count(), 7)
+        self.assertEqual(Device.objects.count(), 9)
         self.assertEqual(response.data[0]['name'], data[0]['name'])
         self.assertEqual(response.data[1]['name'], data[1]['name'])
         self.assertEqual(response.data[2]['name'], data[2]['name'])
@@ -1927,7 +1946,7 @@ class DeviceTest(APITestCase):
         response = self.client.put(url, data, format='json', **self.header)
 
         self.assertHttpStatus(response, status.HTTP_200_OK)
-        self.assertEqual(Device.objects.count(), 4)
+        self.assertEqual(Device.objects.count(), 6)
         device1 = Device.objects.get(pk=response.data['id'])
         self.assertEqual(device1.device_type_id, data['device_type'])
         self.assertEqual(device1.device_role_id, data['device_role'])
@@ -1942,7 +1961,7 @@ class DeviceTest(APITestCase):
         response = self.client.delete(url, **self.header)
 
         self.assertHttpStatus(response, status.HTTP_204_NO_CONTENT)
-        self.assertEqual(Device.objects.count(), 3)
+        self.assertEqual(Device.objects.count(), 5)
 
     def test_config_context_included_by_default_in_list_view(self):
 

+ 7 - 6
netbox/utilities/filters.py

@@ -62,13 +62,14 @@ class TreeNodeMultipleChoiceFilter(django_filters.ModelMultipleChoiceFilter):
     Filters for a set of Models, including all descendant models within a Tree.  Example: [<Region: R1>,<Region: R2>]
     """
 
-    def filter(self, qs, value):
-        if settings.FILTERS_NULL_CHOICE_VALUE in value:
-            # Filtering by null value. Example: region=null
-            qs = self.get_method(qs)(**{self.field_name.replace('in', 'isnull'): True})
-            return qs.distinct() if self.distinct else qs
+    def get_filter_predicate(self, v):
+        # null value filtering
+        if v is None:
+            return {self.field_name.replace('in', 'isnull'): True}
+        return super().get_filter_predicate(v)
 
-        value = [node.get_descendants(include_self=True) for node in value]
+    def filter(self, qs, value):
+        value = [node.get_descendants(include_self=True) if not isinstance(node, str) else node for node in value]
         return super().filter(qs, value)
 
 

+ 16 - 9
netbox/virtualization/tests/test_api.py

@@ -350,6 +350,8 @@ class VirtualMachineTest(APITestCase):
                 'B': 2
             }
         )
+        self.virtualmachine_non_region1 = VirtualMachine.objects.create(name='Test Virtual Machine Null Region1', cluster=self.cluster2)
+        self.virtualmachine_non_region2 = VirtualMachine.objects.create(name='Test Virtual Machine Null Region2', cluster=self.cluster2)
 
     def test_get_virtualmachine(self):
 
@@ -363,7 +365,7 @@ class VirtualMachineTest(APITestCase):
         url = reverse('virtualization-api:virtualmachine-list')
         response = self.client.get(url, **self.header)
 
-        self.assertEqual(response.data['count'], 4)
+        self.assertEqual(response.data['count'], 6)
 
     def test_list_virtualmachines_brief(self):
 
@@ -377,12 +379,17 @@ class VirtualMachineTest(APITestCase):
 
     def test_list_virtualmachines_null_region(self):
 
-        VirtualMachine.objects.create(name='Test Virtual Machine Null Region', cluster=self.cluster2)
-
         url = reverse('virtualization-api:virtualmachine-list')
         response = self.client.get('{}?region=null'.format(url), **self.header)
 
-        self.assertEqual(response.data['count'], 1)
+        self.assertEqual(response.data['count'], 2)
+
+    def test_list_virtualmachines_multiple_regions(self):
+
+        url = reverse('virtualization-api:virtualmachine-list')
+        response = self.client.get('{}?region=null&region=test-region-1'.format(url), **self.header)
+
+        self.assertEqual(response.data['count'], 6)
 
     def test_create_virtualmachine(self):
 
@@ -395,7 +402,7 @@ class VirtualMachineTest(APITestCase):
         response = self.client.post(url, data, format='json', **self.header)
 
         self.assertHttpStatus(response, status.HTTP_201_CREATED)
-        self.assertEqual(VirtualMachine.objects.count(), 5)
+        self.assertEqual(VirtualMachine.objects.count(), 7)
         virtualmachine4 = VirtualMachine.objects.get(pk=response.data['id'])
         self.assertEqual(virtualmachine4.name, data['name'])
         self.assertEqual(virtualmachine4.cluster.pk, data['cluster'])
@@ -410,7 +417,7 @@ class VirtualMachineTest(APITestCase):
         response = self.client.post(url, data, format='json', **self.header)
 
         self.assertHttpStatus(response, status.HTTP_400_BAD_REQUEST)
-        self.assertEqual(VirtualMachine.objects.count(), 4)
+        self.assertEqual(VirtualMachine.objects.count(), 6)
 
     def test_create_virtualmachine_bulk(self):
 
@@ -433,7 +440,7 @@ class VirtualMachineTest(APITestCase):
         response = self.client.post(url, data, format='json', **self.header)
 
         self.assertHttpStatus(response, status.HTTP_201_CREATED)
-        self.assertEqual(VirtualMachine.objects.count(), 7)
+        self.assertEqual(VirtualMachine.objects.count(), 9)
         self.assertEqual(response.data[0]['name'], data[0]['name'])
         self.assertEqual(response.data[1]['name'], data[1]['name'])
         self.assertEqual(response.data[2]['name'], data[2]['name'])
@@ -455,7 +462,7 @@ class VirtualMachineTest(APITestCase):
         response = self.client.put(url, data, format='json', **self.header)
 
         self.assertHttpStatus(response, status.HTTP_200_OK)
-        self.assertEqual(VirtualMachine.objects.count(), 4)
+        self.assertEqual(VirtualMachine.objects.count(), 6)
         virtualmachine1 = VirtualMachine.objects.get(pk=response.data['id'])
         self.assertEqual(virtualmachine1.name, data['name'])
         self.assertEqual(virtualmachine1.cluster.pk, data['cluster'])
@@ -468,7 +475,7 @@ class VirtualMachineTest(APITestCase):
         response = self.client.delete(url, **self.header)
 
         self.assertHttpStatus(response, status.HTTP_204_NO_CONTENT)
-        self.assertEqual(VirtualMachine.objects.count(), 3)
+        self.assertEqual(VirtualMachine.objects.count(), 5)
 
     def test_config_context_included_by_default_in_list_view(self):