Преглед изворни кода

Added initial GraphQL tests

jeremystretch пре 4 година
родитељ
комит
91d39cc0c0

+ 5 - 5
netbox/circuits/tests/test_api.py

@@ -15,7 +15,7 @@ class AppTest(APITestCase):
         self.assertEqual(response.status_code, 200)
 
 
-class ProviderTest(APIViewTestCases.APIViewTestCase):
+class ProviderTest(APIViewTestCases.GraphQLTestCase, APIViewTestCases.APIViewTestCase):
     model = Provider
     brief_fields = ['circuit_count', 'display', 'id', 'name', 'slug', 'url']
     create_data = [
@@ -47,7 +47,7 @@ class ProviderTest(APIViewTestCases.APIViewTestCase):
         Provider.objects.bulk_create(providers)
 
 
-class CircuitTypeTest(APIViewTestCases.APIViewTestCase):
+class CircuitTypeTest(APIViewTestCases.GraphQLTestCase, APIViewTestCases.APIViewTestCase):
     model = CircuitType
     brief_fields = ['circuit_count', 'display', 'id', 'name', 'slug', 'url']
     create_data = (
@@ -79,7 +79,7 @@ class CircuitTypeTest(APIViewTestCases.APIViewTestCase):
         CircuitType.objects.bulk_create(circuit_types)
 
 
-class CircuitTest(APIViewTestCases.APIViewTestCase):
+class CircuitTest(APIViewTestCases.GraphQLTestCase, APIViewTestCases.APIViewTestCase):
     model = Circuit
     brief_fields = ['cid', 'display', 'id', 'url']
     bulk_update_data = {
@@ -127,7 +127,7 @@ class CircuitTest(APIViewTestCases.APIViewTestCase):
         ]
 
 
-class CircuitTerminationTest(APIViewTestCases.APIViewTestCase):
+class CircuitTerminationTest(APIViewTestCases.GraphQLTestCase, APIViewTestCases.APIViewTestCase):
     model = CircuitTermination
     brief_fields = ['_occupied', 'cable', 'circuit', 'display', 'id', 'term_side', 'url']
 
@@ -180,7 +180,7 @@ class CircuitTerminationTest(APIViewTestCases.APIViewTestCase):
         }
 
 
-class ProviderNetworkTest(APIViewTestCases.APIViewTestCase):
+class ProviderNetworkTest(APIViewTestCases.GraphQLTestCase, APIViewTestCases.APIViewTestCase):
     model = ProviderNetwork
     brief_fields = ['display', 'id', 'name', 'url']
 

+ 27 - 0
netbox/netbox/tests/test_graphql.py

@@ -0,0 +1,27 @@
+from django.test import override_settings
+from django.urls import reverse
+
+from utilities.testing import disable_warnings, TestCase
+
+
+class GraphQLTestCase(TestCase):
+
+    @override_settings(LOGIN_REQUIRED=True)
+    def test_graphiql_interface(self):
+        """
+        Test rendering of the GraphiQL interactive web interface
+        """
+        url = reverse('graphql')
+        header = {
+            'HTTP_ACCEPT': 'text/html',
+        }
+
+        # Authenticated request
+        response = self.client.get(url, **header)
+        self.assertHttpStatus(response, 200)
+
+        # Non-authenticated request
+        self.client.logout()
+        response = self.client.get(url, **header)
+        with disable_warnings('django.request'):
+            self.assertHttpStatus(response, 302)

+ 1 - 1
netbox/netbox/urls.py

@@ -63,7 +63,7 @@ _patterns = [
     re_path(r'^api/swagger(?P<format>.json|.yaml)$', schema_view.without_ui(), name='schema_swagger'),
 
     # GraphQL
-    path('graphql/', GraphQLView.as_view(graphiql=True, schema=schema)),
+    path('graphql/', GraphQLView.as_view(graphiql=True, schema=schema), name='graphql'),
 
     # Serving static media in Django to pipe it through LoginRequiredMiddleware
     path('media/<path:path>', serve, {'document_root': settings.MEDIA_ROOT}),

+ 45 - 0
netbox/utilities/testing/api.py

@@ -1,3 +1,5 @@
+import json
+
 from django.conf import settings
 from django.contrib.auth.models import User
 from django.contrib.contenttypes.models import ContentType
@@ -421,6 +423,49 @@ class APIViewTestCases:
             self.assertHttpStatus(response, status.HTTP_204_NO_CONTENT)
             self.assertEqual(self._get_queryset().count(), initial_count - 3)
 
+    class GraphQLTestCase(APITestCase):
+
+        def test_graphql_get_object(self):
+            url = reverse('graphql')
+            object_type = self.model._meta.verbose_name.replace(' ', '_')
+            object_id = self._get_queryset().first().pk
+            query = f"""
+            {{
+                {object_type}(id:{object_id}) {{
+                    id
+                }}
+            }}
+            """
+
+            # Non-authenticated requests should fail
+            with disable_warnings('django.request'):
+                self.assertHttpStatus(self.client.post(url, data={'query': query}), status.HTTP_403_FORBIDDEN)
+
+            response = self.client.post(url, data={'query': query}, **self.header)
+            self.assertHttpStatus(response, status.HTTP_200_OK)
+            data = json.loads(response.content)
+            self.assertNotIn('errors', data)
+
+        def test_graphql_list_objects(self):
+            url = reverse('graphql')
+            object_type = self.model._meta.verbose_name_plural.replace(' ', '_')
+            query = f"""
+            {{
+                {object_type} {{
+                    id
+                }}
+            }}
+            """
+
+            # Non-authenticated requests should fail
+            with disable_warnings('django.request'):
+                self.assertHttpStatus(self.client.post(url, data={'query': query}), status.HTTP_403_FORBIDDEN)
+
+            response = self.client.post(url, data={'query': query}, **self.header)
+            self.assertHttpStatus(response, status.HTTP_200_OK)
+            data = json.loads(response.content)
+            self.assertNotIn('errors', data)
+
     class APIViewTestCase(
         GetObjectViewTestCase,
         ListObjectsViewTestCase,