Browse Source

19724 add the v2 to graphql testing

Arthur 3 months ago
parent
commit
af55da008b
1 changed files with 112 additions and 55 deletions
  1. 112 55
      netbox/utilities/testing/api.py

+ 112 - 55
netbox/utilities/testing/api.py

@@ -515,10 +515,15 @@ class APIViewTestCases:
             base_name = self.model._meta.verbose_name.lower().replace(' ', '_')
             return getattr(self, 'graphql_base_name', base_name)
 
-        def _build_query_with_filter(self, name, filter_string):
+        def _build_query_with_filter(self, name, filter_string, api_version='v2'):
             """
             Called by either _build_query or _build_filtered_query - construct the actual
             query given a name and filter string
+
+            Args:
+                name: The query field name (e.g., 'device_list')
+                filter_string: Filter parameters string (e.g., '(filters: {id: "1"})')
+                api_version: 'v1' or 'v2' to determine response format
             """
             type_class = get_graphql_type_for_model(self.model)
 
@@ -564,16 +569,26 @@ class APIViewTestCases:
 
             # Check if this is a list query (ends with '_list')
             if name.endswith('_list'):
-                # Wrap fields in 'results' for paginated queries
-                query = f"""
-                {{
-                    {name}{filter_string} {{
-                        results {{
+                if api_version == 'v2':
+                    # v2: Wrap fields in 'results' for paginated queries
+                    query = f"""
+                    {{
+                        {name}{filter_string} {{
+                            results {{
+                                {fields_string}
+                            }}
+                        }}
+                    }}
+                    """
+                else:
+                    # v1: Return direct array (no 'results' wrapper)
+                    query = f"""
+                    {{
+                        {name}{filter_string} {{
                             {fields_string}
                         }}
                     }}
-                }}
-                """
+                    """
             else:
                 # Single object query (no pagination)
                 query = f"""
@@ -586,9 +601,14 @@ class APIViewTestCases:
 
             return query
 
-        def _build_filtered_query(self, name, **filters):
+        def _build_filtered_query(self, name, api_version='v2', **filters):
             """
             Create a filtered query: i.e. device_list(filters: {name: {i_contains: "akron"}}){.
+
+            Args:
+                name: The query field name
+                api_version: 'v1' or 'v2' to determine response format
+                **filters: Filter parameters
             """
             # TODO: This should be extended to support AND, OR multi-lookups
             if filters:
@@ -604,11 +624,16 @@ class APIViewTestCases:
             else:
                 filter_string = ''
 
-            return self._build_query_with_filter(name, filter_string)
+            return self._build_query_with_filter(name, filter_string, api_version)
 
-        def _build_query(self, name, **filters):
+        def _build_query(self, name, api_version='v2', **filters):
             """
             Create a normal query - unfiltered or with a string query: i.e. site(name: "aaa"){.
+
+            Args:
+                name: The query field name
+                api_version: 'v1' or 'v2' to determine response format
+                **filters: Filter parameters
             """
             if filters:
                 filter_string = ', '.join(f'{k}:{v}' for k, v in filters.items())
@@ -616,7 +641,7 @@ class APIViewTestCases:
             else:
                 filter_string = ''
 
-            return self._build_query_with_filter(name, filter_string)
+            return self._build_query_with_filter(name, filter_string, api_version)
 
         @override_settings(LOGIN_REQUIRED=True)
         def test_graphql_get_object(self):
@@ -664,54 +689,71 @@ class APIViewTestCases:
 
         @override_settings(LOGIN_REQUIRED=True)
         def test_graphql_list_objects(self):
-            url = reverse('graphql_v2')
             field_name = f'{self._get_graphql_base_name()}_list'
-            query = self._build_query(field_name)
-
-            # Non-authenticated requests should fail
-            header = {
-                'HTTP_ACCEPT': 'application/json',
-            }
-            with disable_warnings('django.request'):
-                response = self.client.post(url, data={'query': query}, format="json", **header)
-            self.assertHttpStatus(response, status.HTTP_403_FORBIDDEN)
-
-            # Add constrained permission
-            obj_perm = ObjectPermission(
-                name='Test permission',
-                actions=['view'],
-                constraints={'id': 0}  # Impossible constraint
-            )
-            obj_perm.save()
-            obj_perm.users.add(self.user)
-            obj_perm.object_types.add(ObjectType.objects.get_for_model(self.model))
 
-            # Request should succeed but return empty results list
-            response = self.client.post(url, data={'query': query}, format="json", **self.header)
-            self.assertHttpStatus(response, status.HTTP_200_OK)
-            data = json.loads(response.content)
-            self.assertNotIn('errors', data)
-            self.assertEqual(len(data['data'][field_name]['results']), 0)
-
-            # Remove permission constraint
-            obj_perm.constraints = None
-            obj_perm.save()
+            # Test both GraphQL API versions
+            for api_version, url_name in [('v1', 'graphql_v1'), ('v2', 'graphql_v2')]:
+                with self.subTest(api_version=api_version):
+                    url = reverse(url_name)
+                    query = self._build_query(field_name, api_version=api_version)
+
+                    # Non-authenticated requests should fail
+                    header = {
+                        'HTTP_ACCEPT': 'application/json',
+                    }
+                    with disable_warnings('django.request'):
+                        response = self.client.post(url, data={'query': query}, format="json", **header)
+                    self.assertHttpStatus(response, status.HTTP_403_FORBIDDEN)
+
+                    # Add constrained permission
+                    obj_perm = ObjectPermission(
+                        name='Test permission',
+                        actions=['view'],
+                        constraints={'id': 0}  # Impossible constraint
+                    )
+                    obj_perm.save()
+                    obj_perm.users.add(self.user)
+                    obj_perm.object_types.add(ObjectType.objects.get_for_model(self.model))
+
+                    # Request should succeed but return empty results list
+                    response = self.client.post(url, data={'query': query}, format="json", **self.header)
+                    self.assertHttpStatus(response, status.HTTP_200_OK)
+                    data = json.loads(response.content)
+                    self.assertNotIn('errors', data)
+
+                    if api_version == 'v1':
+                        # v1 returns direct array
+                        self.assertEqual(len(data['data'][field_name]), 0)
+                    else:
+                        # v2 returns paginated response with results
+                        self.assertEqual(len(data['data'][field_name]['results']), 0)
+
+                    # Remove permission constraint
+                    obj_perm.constraints = None
+                    obj_perm.save()
+
+                    # Request should return all objects
+                    response = self.client.post(url, data={'query': query}, format="json", **self.header)
+                    self.assertHttpStatus(response, status.HTTP_200_OK)
+                    data = json.loads(response.content)
+                    self.assertNotIn('errors', data)
+
+                    if api_version == 'v1':
+                        # v1 returns direct array
+                        self.assertEqual(len(data['data'][field_name]), self.model.objects.count())
+                    else:
+                        # v2 returns paginated response with results
+                        self.assertEqual(len(data['data'][field_name]['results']), self.model.objects.count())
 
-            # Request should return all objects
-            response = self.client.post(url, data={'query': query}, format="json", **self.header)
-            self.assertHttpStatus(response, status.HTTP_200_OK)
-            data = json.loads(response.content)
-            self.assertNotIn('errors', data)
-            self.assertEqual(len(data['data'][field_name]['results']), self.model.objects.count())
+                    # Clean up permission for next iteration
+                    obj_perm.delete()
 
         @override_settings(LOGIN_REQUIRED=True)
         def test_graphql_filter_objects(self):
             if not hasattr(self, 'graphql_filter'):
                 return
 
-            url = reverse('graphql_v2')
             field_name = f'{self._get_graphql_base_name()}_list'
-            query = self._build_filtered_query(field_name, **self.graphql_filter)
 
             # Add object-level permission
             obj_perm = ObjectPermission(
@@ -722,11 +764,26 @@ class APIViewTestCases:
             obj_perm.users.add(self.user)
             obj_perm.object_types.add(ObjectType.objects.get_for_model(self.model))
 
-            response = self.client.post(url, data={'query': query}, format="json", **self.header)
-            self.assertHttpStatus(response, status.HTTP_200_OK)
-            data = json.loads(response.content)
-            self.assertNotIn('errors', data)
-            self.assertGreater(len(data['data'][field_name]['results']), 0)
+            # Test both GraphQL API versions
+            for api_version, url_name in [('v1', 'graphql_v1'), ('v2', 'graphql_v2')]:
+                with self.subTest(api_version=api_version):
+                    url = reverse(url_name)
+                    query = self._build_filtered_query(field_name, api_version=api_version, **self.graphql_filter)
+
+                    response = self.client.post(url, data={'query': query}, format="json", **self.header)
+                    self.assertHttpStatus(response, status.HTTP_200_OK)
+                    data = json.loads(response.content)
+                    self.assertNotIn('errors', data)
+
+                    if api_version == 'v1':
+                        # v1 returns direct array
+                        self.assertGreater(len(data['data'][field_name]), 0)
+                    else:
+                        # v2 returns paginated response with results
+                        self.assertGreater(len(data['data'][field_name]['results']), 0)
+
+            # Clean up permission
+            obj_perm.delete()
 
     class APIViewTestCase(
         GetObjectViewTestCase,