|
@@ -1,12 +1,15 @@
|
|
|
import json
|
|
import json
|
|
|
|
|
|
|
|
|
|
+import strawberry
|
|
|
from django.test import override_settings
|
|
from django.test import override_settings
|
|
|
from django.urls import reverse
|
|
from django.urls import reverse
|
|
|
from rest_framework import status
|
|
from rest_framework import status
|
|
|
|
|
+from strawberry.types.lazy_type import LazyType
|
|
|
|
|
|
|
|
from core.models import ObjectType
|
|
from core.models import ObjectType
|
|
|
from dcim.choices import LocationStatusChoices
|
|
from dcim.choices import LocationStatusChoices
|
|
|
from dcim.models import Site, Location
|
|
from dcim.models import Site, Location
|
|
|
|
|
+from netbox.graphql.schema import QueryV1, QueryV2
|
|
|
from users.models import ObjectPermission
|
|
from users.models import ObjectPermission
|
|
|
from utilities.testing import disable_warnings, APITestCase, TestCase
|
|
from utilities.testing import disable_warnings, APITestCase, TestCase
|
|
|
|
|
|
|
@@ -45,6 +48,53 @@ class GraphQLTestCase(TestCase):
|
|
|
|
|
|
|
|
class GraphQLAPITestCase(APITestCase):
|
|
class GraphQLAPITestCase(APITestCase):
|
|
|
|
|
|
|
|
|
|
+ def test_versioned_types(self):
|
|
|
|
|
+ """
|
|
|
|
|
+ Check that the GraphQL types defined for each version of the schema (V1 and V2) are correct.
|
|
|
|
|
+ """
|
|
|
|
|
+ schemas = (
|
|
|
|
|
+ (1, QueryV1),
|
|
|
|
|
+ (2, QueryV2),
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ def _get_class_name(field):
|
|
|
|
|
+ try:
|
|
|
|
|
+ if type(field.type) is strawberry.types.base.StrawberryList:
|
|
|
|
|
+ # Skip scalars
|
|
|
|
|
+ if field.type.of_type in (str, int):
|
|
|
|
|
+ return
|
|
|
|
|
+ if type(field.type.of_type) is LazyType:
|
|
|
|
|
+ return field.type.of_type.type_name
|
|
|
|
|
+ return field.type.of_type.__name__
|
|
|
|
|
+ if hasattr(field.type, 'name'):
|
|
|
|
|
+ return field.type.__name__
|
|
|
|
|
+ except AttributeError:
|
|
|
|
|
+ # Unknown field type
|
|
|
|
|
+ return
|
|
|
|
|
+
|
|
|
|
|
+ def _check_version(class_name, version):
|
|
|
|
|
+ if version == 1:
|
|
|
|
|
+ self.assertTrue(class_name.endswith('V1'), f"{class_name} (v1) is not a V1 type")
|
|
|
|
|
+ elif version == 2:
|
|
|
|
|
+ self.assertFalse(class_name.endswith('V1'), f"{class_name} (v2) is a V1 type")
|
|
|
|
|
+
|
|
|
|
|
+ for version, query in schemas:
|
|
|
|
|
+ schema = strawberry.Schema(query=query)
|
|
|
|
|
+ query_type = schema.get_type_by_name(query.__name__)
|
|
|
|
|
+
|
|
|
|
|
+ # Iterate through root fields
|
|
|
|
|
+ for field in query_type.fields:
|
|
|
|
|
+ # Check for V1 suffix on class names
|
|
|
|
|
+ if type_class := _get_class_name(field):
|
|
|
|
|
+ _check_version(type_class, version)
|
|
|
|
|
+
|
|
|
|
|
+ # Iterate through nested fields
|
|
|
|
|
+ subquery_type = schema.get_type_by_name(type_class)
|
|
|
|
|
+ for subfield in subquery_type.fields:
|
|
|
|
|
+ # Check for V1 suffix on class names
|
|
|
|
|
+ if type_class := _get_class_name(subfield):
|
|
|
|
|
+ _check_version(type_class, version)
|
|
|
|
|
+
|
|
|
@override_settings(LOGIN_REQUIRED=True)
|
|
@override_settings(LOGIN_REQUIRED=True)
|
|
|
def test_graphql_filter_objects(self):
|
|
def test_graphql_filter_objects(self):
|
|
|
"""
|
|
"""
|