Просмотр исходного кода

Merge pull request #22455 from netbox-community/22451-pass-strawberry-graphql-extension-factories-instead-of

Closes #22451: Use factories for GraphQL schema extension initialization
bctiemann 2 недель назад
Родитель
Сommit
2d496ca069
2 измененных файлов с 20 добавлено и 11 удалено
  1. 12 6
      netbox/netbox/graphql/schema.py
  2. 8 5
      netbox/netbox/tests/test_graphql.py

+ 12 - 6
netbox/netbox/graphql/schema.py

@@ -1,3 +1,5 @@
+from collections.abc import Callable
+
 import strawberry
 from django.conf import settings
 from strawberry.extensions import MaxAliasesLimiter, QueryDepthLimiter, SchemaExtension
@@ -18,6 +20,8 @@ from wireless.graphql.schema import WirelessQuery
 
 from .scalars import BigInt, BigIntScalar
 
+SchemaExtensionFactory = type[SchemaExtension] | Callable[[], SchemaExtension]
+
 
 @strawberry.type
 class Query(
@@ -36,14 +40,16 @@ class Query(
     pass
 
 
-def get_schema_extensions() -> list[SchemaExtension]:
-    extensions: list[SchemaExtension] = [
-        DjangoOptimizerExtension(prefetch_custom_queryset=True),
-        MaxAliasesLimiter(max_alias_count=settings.GRAPHQL_MAX_ALIASES),
-    ]
+def get_schema_extensions() -> list[SchemaExtensionFactory]:
+    max_aliases = settings.GRAPHQL_MAX_ALIASES
     max_depth = settings.GRAPHQL_MAX_QUERY_DEPTH
+
+    extensions: list[SchemaExtensionFactory] = [
+        lambda: DjangoOptimizerExtension(prefetch_custom_queryset=True),
+        lambda: MaxAliasesLimiter(max_alias_count=max_aliases),
+    ]
     if max_depth and max_depth > 0:
-        extensions.append(QueryDepthLimiter(max_depth=max_depth))
+        extensions.append(lambda: QueryDepthLimiter(max_depth=max_depth))
     return extensions
 
 

+ 8 - 5
netbox/netbox/tests/test_graphql.py

@@ -19,6 +19,9 @@ from utilities.testing import APITestCase, TestCase, disable_warnings
 
 class GraphQLTestCase(TestCase):
 
+    def _schema_extension_instances(self):
+        return [factory() for factory in get_schema_extensions()]
+
     @override_settings(GRAPHQL_ENABLED=False)
     def test_graphql_enabled(self):
         """
@@ -32,21 +35,21 @@ class GraphQLTestCase(TestCase):
         """
         QueryDepthLimiter should not be installed when GRAPHQL_MAX_QUERY_DEPTH is unset.
         """
-        self.assertFalse(any(isinstance(ext, QueryDepthLimiter) for ext in get_schema_extensions()))
+        self.assertFalse(any(isinstance(ext, QueryDepthLimiter) for ext in self._schema_extension_instances()))
 
     @override_settings(GRAPHQL_MAX_QUERY_DEPTH=0)
     def test_graphql_max_query_depth_disabled_when_zero(self):
         """
         QueryDepthLimiter should not be installed when GRAPHQL_MAX_QUERY_DEPTH is zero.
         """
-        self.assertFalse(any(isinstance(ext, QueryDepthLimiter) for ext in get_schema_extensions()))
+        self.assertFalse(any(isinstance(ext, QueryDepthLimiter) for ext in self._schema_extension_instances()))
 
     @override_settings(GRAPHQL_MAX_QUERY_DEPTH=-1)
     def test_graphql_max_query_depth_disabled_when_negative(self):
         """
         QueryDepthLimiter should not be installed when GRAPHQL_MAX_QUERY_DEPTH is negative.
         """
-        self.assertFalse(any(isinstance(ext, QueryDepthLimiter) for ext in get_schema_extensions()))
+        self.assertFalse(any(isinstance(ext, QueryDepthLimiter) for ext in self._schema_extension_instances()))
 
     @override_settings(GRAPHQL_MAX_QUERY_DEPTH=3)
     def test_graphql_max_query_depth_enforced(self):
@@ -54,9 +57,9 @@ class GraphQLTestCase(TestCase):
         Queries exceeding GRAPHQL_MAX_QUERY_DEPTH should be rejected.
         """
         extensions = get_schema_extensions()
-        self.assertTrue(any(isinstance(ext, QueryDepthLimiter) for ext in extensions))
+        self.assertTrue(any(isinstance(ext, QueryDepthLimiter) for ext in self._schema_extension_instances()))
 
-        # Build a temporary schema with the configured extensions and execute a deep query
+        # Build a temporary schema with the configured extension factories and execute a deep query
         test_schema = strawberry.Schema(
             query=Query,
             config=StrawberryConfig(auto_camel_case=False, scalar_map={BigInt: BigIntScalar}),