Просмотр исходного кода

Extended GraphQL tests to include all fields

jeremystretch 4 лет назад
Родитель
Сommit
7deabfe9cd
3 измененных файлов с 52 добавлено и 16 удалено
  1. 4 0
      netbox/netbox/api/exceptions.py
  2. 14 2
      netbox/utilities/api.py
  3. 34 14
      netbox/utilities/testing/api.py

+ 4 - 0
netbox/netbox/api/exceptions.py

@@ -8,3 +8,7 @@ class ServiceUnavailable(APIException):
 
 class SerializerNotFound(Exception):
     pass
+
+
+class GraphQLTypeNotFound(Exception):
+    pass

+ 14 - 2
netbox/utilities/api.py

@@ -7,7 +7,7 @@ from django.urls import reverse
 from rest_framework import status
 from rest_framework.utils import formatting
 
-from netbox.api.exceptions import SerializerNotFound
+from netbox.api.exceptions import GraphQLTypeNotFound, SerializerNotFound
 from .utils import dynamic_import
 
 
@@ -24,10 +24,22 @@ def get_serializer_for_model(model, prefix=''):
         return dynamic_import(serializer_name)
     except AttributeError:
         raise SerializerNotFound(
-            "Could not determine serializer for {}.{} with prefix '{}'".format(app_name, model_name, prefix)
+            f"Could not determine serializer for {app_name}.{model_name} with prefix '{prefix}'"
         )
 
 
+def get_graphql_type_for_model(model):
+    """
+    Return the GraphQL type class for the given model.
+    """
+    app_name, model_name = model._meta.label.split('.')
+    class_name = f'{app_name}.graphql.types.{model_name}Type'
+    try:
+        return dynamic_import(class_name)
+    except AttributeError:
+        raise GraphQLTypeNotFound(f"Could not find GraphQL type for {app_name}.{model_name}")
+
+
 def is_api_request(request):
     """
     Return True of the request is being made via the REST API.

+ 34 - 14
netbox/utilities/testing/api.py

@@ -5,12 +5,14 @@ from django.contrib.auth.models import User
 from django.contrib.contenttypes.models import ContentType
 from django.urls import reverse
 from django.test import override_settings
+from graphene.types.dynamic import Dynamic
 from rest_framework import status
 from rest_framework.test import APIClient
 
 from extras.choices import ObjectChangeActionChoices
 from extras.models import ObjectChange
 from users.models import ObjectPermission, Token
+from utilities.api import get_graphql_type_for_model
 from .base import ModelTestCase
 from .utils import disable_warnings
 
@@ -431,19 +433,43 @@ class APIViewTestCases:
                                self.model._meta.verbose_name_plural.lower().replace(' ', '_'))
             return getattr(self, 'graphql_base_name', self.model._meta.verbose_name.lower().replace(' ', '_'))
 
-        @override_settings(LOGIN_REQUIRED=True)
-        def test_graphql_get_object(self):
-            url = reverse('graphql')
-            object_type = self._get_graphql_base_name()
-            object_id = self._get_queryset().first().pk
+        def _build_query(self, name, **filters):
+            type_class = get_graphql_type_for_model(self.model)
+            if filters:
+                filter_string = ', '.join(f'{k}:{v}' for k, v in filters.items())
+                filter_string = f'({filter_string})'
+            else:
+                filter_string = ''
+
+            # Compile list of fields to include
+            fields_string = ''
+            for field_name, field in type_class._meta.fields.items():
+                # TODO: Omit "hidden" fields from GraphQL types
+                if field_name.startswith('_'):
+                    continue
+                if type(field) is Dynamic:
+                    # Dynamic fields must specify a subselection
+                    fields_string += f'{field_name} {{ id }}\n'
+                else:
+                    fields_string += f'{field_name}\n'
+
             query = f"""
             {{
-                {object_type}(id:{object_id}) {{
-                    id
+                {name}{filter_string} {{
+                    {fields_string}
                 }}
             }}
             """
 
+            return query
+
+        @override_settings(LOGIN_REQUIRED=True)
+        def test_graphql_get_object(self):
+            url = reverse('graphql')
+            object_type = self._get_graphql_base_name()
+            object_id = self._get_queryset().first().pk
+            query = self._build_query(object_type, id=object_id)
+
             # Non-authenticated requests should fail
             with disable_warnings('django.request'):
                 self.assertHttpStatus(self.client.post(url, data={'query': query}), status.HTTP_403_FORBIDDEN)
@@ -466,13 +492,7 @@ class APIViewTestCases:
         def test_graphql_list_objects(self):
             url = reverse('graphql')
             object_type = self._get_graphql_base_name(plural=True)
-            query = f"""
-            {{
-                {object_type} {{
-                    id
-                }}
-            }}
-            """
+            query = self._build_query(object_type)
 
             # Non-authenticated requests should fail
             with disable_warnings('django.request'):