Selaa lähdekoodia

16078 make GraphQL NumberFilter optional (#16115)

* 16078 make GraphQL NumberFilter optional

* 16078 add tests for graphql filtering

* 16078 add tests for graphql filtering

* 16078 add tests for graphql filtering
Arthur Hanson 1 vuosi sitten
vanhempi
commit
b291aa4312

+ 3 - 0
netbox/ipam/tests/test_api.py

@@ -648,6 +648,9 @@ class IPAddressTest(APIViewTestCases.APIViewTestCase):
     bulk_update_data = {
         'description': 'New description',
     }
+    graphql_filter = {
+        'address': '192.168.0.1/24',
+    }
 
     @classmethod
     def setUpTestData(cls):

+ 1 - 1
netbox/netbox/graphql/filter_mixins.py

@@ -87,7 +87,7 @@ def map_strawberry_type(field):
         pass
     elif issubclass(type(field), django_filters.NumberFilter):
         should_create_function = True
-        attr_type = int
+        attr_type = int | None
     elif issubclass(type(field), django_filters.ModelMultipleChoiceFilter):
         should_create_function = True
         attr_type = List[str] | None

+ 54 - 6
netbox/utilities/testing/api.py

@@ -440,13 +440,12 @@ class APIViewTestCases:
             base_name = self.model._meta.verbose_name.lower().replace(' ', '_')
             return getattr(self, 'graphql_base_name', base_name)
 
-        def _build_query(self, name, **filters):
+        def _build_query_with_filter(self, name, filter_string):
+            """
+            Called by either _build_query or _build_filtered_query - construct the actual
+            query given a name and filter string
+            """
             type_class = get_graphql_type_for_model(self.model)
-            if filters:
-                filter_string = ', '.join(f'{k}:{v}' for k, v in filters.items())
-                filter_string = f'({filter_string})'
-            else:
-                filter_string = ''
 
             # Compile list of fields to include
             fields_string = ''
@@ -492,6 +491,30 @@ class APIViewTestCases:
 
             return query
 
+        def _build_filtered_query(self, name, **filters):
+            """
+            Create a filtered query: i.e. ip_address_list(filters: {address: "1.1.1.1/24"}){.
+            """
+            if filters:
+                filter_string = ', '.join(f'{k}: "{v}"' for k, v in filters.items())
+                filter_string = f'(filters: {{{filter_string}}})'
+            else:
+                filter_string = ''
+
+            return self._build_query_with_filter(name, filter_string)
+
+        def _build_query(self, name, **filters):
+            """
+            Create a normal query - unfiltered or with a string query: i.e. site(name: "aaa"){.
+            """
+            if filters:
+                filter_string = ', '.join(f'{k}:{v}' for k, v in filters.items())
+                filter_string = f'({filter_string})'
+            else:
+                filter_string = ''
+
+            return self._build_query_with_filter(name, filter_string)
+
         @override_settings(LOGIN_REQUIRED=True)
         @override_settings(EXEMPT_VIEW_PERMISSIONS=['*', 'auth.user'])
         def test_graphql_get_object(self):
@@ -550,6 +573,31 @@ class APIViewTestCases:
             self.assertNotIn('errors', data)
             self.assertGreater(len(data['data'][field_name]), 0)
 
+        @override_settings(LOGIN_REQUIRED=True)
+        @override_settings(EXEMPT_VIEW_PERMISSIONS=['*', 'auth.user'])
+        def test_graphql_filter_objects(self):
+            if not hasattr(self, 'graphql_filter'):
+                return
+
+            url = reverse('graphql')
+            field_name = f'{self._get_graphql_base_name()}_list'
+            query = self._build_filtered_query(field_name, **self.graphql_filter)
+
+            # Add object-level permission
+            obj_perm = ObjectPermission(
+                name='Test permission',
+                actions=['view']
+            )
+            obj_perm.save()
+            obj_perm.users.add(self.user)
+            obj_perm.object_types.add(ObjectType.objects.get_for_model(self.model))
+
+            response = self.client.post(url, data={'query': query}, format="json", **self.header)
+            self.assertHttpStatus(response, status.HTTP_200_OK)
+            data = json.loads(response.content)
+            self.assertNotIn('errors', data)
+            self.assertGreater(len(data['data'][field_name]), 0)
+
     class APIViewTestCase(
         GetObjectViewTestCase,
         ListObjectsViewTestCase,