Quellcode durchsuchen

Add a test to validate versioned GraphQL types

Jeremy Stretch vor 3 Monaten
Ursprung
Commit
47ac506d5c
1 geänderte Dateien mit 50 neuen und 0 gelöschten Zeilen
  1. 50 0
      netbox/netbox/tests/test_graphql.py

+ 50 - 0
netbox/netbox/tests/test_graphql.py

@@ -1,12 +1,15 @@
 import json
 import json
 
 
+import strawberry
 from django.test import override_settings
 from django.test import override_settings
 from django.urls import reverse
 from django.urls import reverse
 from rest_framework import status
 from rest_framework import status
+from strawberry.types.lazy_type import LazyType
 
 
 from core.models import ObjectType
 from core.models import ObjectType
 from dcim.choices import LocationStatusChoices
 from dcim.choices import LocationStatusChoices
 from dcim.models import Site, Location
 from dcim.models import Site, Location
+from netbox.graphql.schema import QueryV1, QueryV2
 from users.models import ObjectPermission
 from users.models import ObjectPermission
 from utilities.testing import disable_warnings, APITestCase, TestCase
 from utilities.testing import disable_warnings, APITestCase, TestCase
 
 
@@ -45,6 +48,53 @@ class GraphQLTestCase(TestCase):
 
 
 class GraphQLAPITestCase(APITestCase):
 class GraphQLAPITestCase(APITestCase):
 
 
+    def test_versioned_types(self):
+        """
+        Check that the GraphQL types defined for each version of the schema (V1 and V2) are correct.
+        """
+        schemas = (
+            (1, QueryV1),
+            (2, QueryV2),
+        )
+
+        def _get_class_name(field):
+            try:
+                if type(field.type) is strawberry.types.base.StrawberryList:
+                    # Skip scalars
+                    if field.type.of_type in (str, int):
+                        return
+                    if type(field.type.of_type) is LazyType:
+                        return field.type.of_type.type_name
+                    return field.type.of_type.__name__
+                if hasattr(field.type, 'name'):
+                    return field.type.__name__
+            except AttributeError:
+                # Unknown field type
+                return
+
+        def _check_version(class_name, version):
+            if version == 1:
+                self.assertTrue(class_name.endswith('V1'), f"{class_name} (v1) is not a V1 type")
+            elif version == 2:
+                self.assertFalse(class_name.endswith('V1'), f"{class_name} (v2) is a V1 type")
+
+        for version, query in schemas:
+            schema = strawberry.Schema(query=query)
+            query_type = schema.get_type_by_name(query.__name__)
+
+            # Iterate through root fields
+            for field in query_type.fields:
+                # Check for V1 suffix on class names
+                if type_class := _get_class_name(field):
+                    _check_version(type_class, version)
+
+                    # Iterate through nested fields
+                    subquery_type = schema.get_type_by_name(type_class)
+                    for subfield in subquery_type.fields:
+                        # Check for V1 suffix on class names
+                        if type_class := _get_class_name(subfield):
+                            _check_version(type_class, version)
+
     @override_settings(LOGIN_REQUIRED=True)
     @override_settings(LOGIN_REQUIRED=True)
     def test_graphql_filter_objects(self):
     def test_graphql_filter_objects(self):
         """
         """