Arthur před 1 rokem
rodič
revize
14f04453bb

+ 1 - 1
netbox/dcim/graphql/types.py

@@ -541,7 +541,7 @@ class LocationType(VLANGroupsMixin, ImageAttachmentsMixin, ContactsMixin, Organi
         return self.vlan_groups.all()
 
     @strawberry_django.field
-    def parent(self) -> Annotated["LocationType", strawberry.lazy('dcim.graphql.types')]:
+    def parent(self) -> Annotated["LocationType", strawberry.lazy('dcim.graphql.types')] | None:
         return self.parent
 
     @strawberry_django.field

+ 6 - 6
netbox/extras/graphql/mixins.py

@@ -24,13 +24,13 @@ if TYPE_CHECKING:
 class ChangelogMixin:
 
     @strawberry_django.field
-    def changelog(self) -> List[Annotated["ObjectChangeType", strawberry.lazy('.types')]]:
+    def changelog(self, info) -> List[Annotated["ObjectChangeType", strawberry.lazy('.types')]]:
         content_type = ContentType.objects.get_for_model(self)
         object_changes = ObjectChange.objects.filter(
             changed_object_type=content_type,
             changed_object_id=self.pk
         )
-        return object_changes.restrict(info.context.user, 'view')
+        return object_changes.restrict(info.context.request.user, 'view')
 
 
 @strawberry.type
@@ -53,16 +53,16 @@ class CustomFieldsMixin:
 class ImageAttachmentsMixin:
 
     @strawberry_django.field
-    def image_attachments(self) -> List[Annotated["ImageAttachmentType", strawberry.lazy('.types')]]:
-        return self.images.restrict(info.context.user, 'view')
+    def image_attachments(self, info) -> List[Annotated["ImageAttachmentType", strawberry.lazy('.types')]]:
+        return self.images.restrict(info.context.request.user, 'view')
 
 
 @strawberry.type
 class JournalEntriesMixin:
 
     @strawberry_django.field
-    def journal_entries(self) -> List[Annotated["JournalEntryType", strawberry.lazy('.types')]]:
-        return self.journal_entries.restrict(info.context.user, 'view')
+    def journal_entries(self, info) -> List[Annotated["JournalEntryType", strawberry.lazy('.types')]]:
+        return self.journal_entries.restrict(info.context.request.user, 'view')
 
 
 @strawberry.type

+ 2 - 2
netbox/ipam/graphql/mixins.py

@@ -10,11 +10,11 @@ class IPAddressesMixin:
     ip_addresses = graphene.List('ipam.graphql.types.IPAddressType')
 
     def resolve_ip_addresses(self, info):
-        return self.ip_addresses.restrict(info.context.user, 'view')
+        return self.ip_addresses.restrict(info.context.request.user, 'view')
 
 
 class VLANGroupsMixin:
     vlan_groups = graphene.List('ipam.graphql.types.VLANGroupType')
 
     def resolve_vlan_groups(self, info):
-        return self.vlan_groups.restrict(info.context.user, 'view')
+        return self.vlan_groups.restrict(info.context.request.user, 'view')

+ 5 - 3
netbox/netbox/graphql/views.py

@@ -2,19 +2,21 @@ from django.conf import settings
 from django.contrib.auth.views import redirect_to_login
 from django.http import HttpResponseNotFound, HttpResponseForbidden
 from django.urls import reverse
-from graphene_django.views import GraphQLView as GraphQLView_
+from django.views.decorators.csrf import csrf_exempt
 from rest_framework.exceptions import AuthenticationFailed
+from strawberry.django.views import GraphQLView
 
 from netbox.api.authentication import TokenAuthentication
 from netbox.config import get_config
 
 
-class GraphQLView(GraphQLView_):
+class NetBoxGraphQLView(GraphQLView):
     """
-    Extends graphene_django's GraphQLView to support DRF's token-based authentication.
+    Extends strawberry's GraphQLView to support DRF's token-based authentication.
     """
     graphiql_template = 'graphiql.html'
 
+    @csrf_exempt
     def dispatch(self, request, *args, **kwargs):
         config = get_config()
 

+ 2 - 3
netbox/netbox/urls.py

@@ -1,14 +1,13 @@
 from django.conf import settings
 from django.conf.urls import include
 from django.urls import path
-from django.views.decorators.csrf import csrf_exempt
 from django.views.static import serve
 from drf_spectacular.views import SpectacularAPIView, SpectacularRedocView, SpectacularSwaggerView
 
 from account.views import LoginView, LogoutView
 from netbox.api.views import APIRootView, StatusView
 from netbox.graphql.schema import schema
-from netbox.graphql.views import GraphQLView
+from netbox.graphql.views import NetBoxGraphQLView
 from netbox.plugins.urls import plugin_patterns, plugin_api_patterns
 from netbox.views import HomeView, StaticMediaFailureView, SearchView, htmx
 from strawberry.django.views import GraphQLView
@@ -61,7 +60,7 @@ _patterns = [
     path('api/schema/redoc/', SpectacularRedocView.as_view(url_name='schema'), name='api_redocs'),
 
     # GraphQL
-    path('graphql/', GraphQLView.as_view(schema=schema), name='graphql'),
+    path('graphql/', NetBoxGraphQLView.as_view(schema=schema), name='graphql'),
 
     # Serving static media in Django to pipe it through LoginRequiredMiddleware
     path('media/<path:path>', serve, {'document_root': settings.MEDIA_ROOT}),

+ 18 - 25
netbox/utilities/testing/api.py

@@ -1,5 +1,6 @@
 import inspect
 import json
+import strawberry_django
 
 from django.conf import settings
 from django.contrib.auth import get_user_model
@@ -18,7 +19,7 @@ from .base import ModelTestCase
 from .utils import disable_warnings
 
 from ipam.graphql.types import IPAddressFamilyType
-
+from strawberry.type import StrawberryList
 
 __all__ = (
     'APITestCase',
@@ -447,36 +448,26 @@ class APIViewTestCases:
             # Compile list of fields to include
             fields_string = ''
 
-            for field_name, field in type_class.__dataclass_fields__.items():
+            for field in type_class.__strawberry_definition__.fields:
                 # for field_name, field in type_class._meta.fields.items():
-                print(f"field_name: {field_name} field: {field}")
-                is_string_array = False
-                if type(field.type) is GQLList:
-                    if field.type.of_type is GQLString:
-                        is_string_array = True
-                    elif type(field.type.of_type) is GQLNonNull and field.type.of_type.of_type is GQLString:
-                        is_string_array = True
-
-                if type(field) is GQLDynamic:
+                print(f"field_name: {field.name} type: {field.type}")
+
+                if type(field.type) is StrawberryList:
+                    fields_string += f'{field.name} {{ id }}\n'
+                elif field.type is strawberry_django.fields.types.DjangoModelType:
                     # Dynamic fields must specify a subselection
-                    fields_string += f'{field_name} {{ id }}\n'
+                    fields_string += f'{field.name} {{ id }}\n'
                 # TODO: Improve field detection logic to avoid nested ArrayFields
-                elif field_name == 'extra_choices':
-                    continue
-                elif inspect.isclass(field.type) and issubclass(field.type, GQLUnion):
-                    # Union types dont' have an id or consistent values
-                    continue
-                elif type(field.type) is GQLList and inspect.isclass(field.type.of_type) and issubclass(field.type.of_type, GQLUnion):
-                    # Union types dont' have an id or consistent values
+                elif field.name == 'extra_choices':
                     continue
-                elif type(field.type) is GQLList and not is_string_array:
-                    # TODO: Come up with something more elegant
-                    # Temporary hack to support automated testing of reverse generic relations
-                    fields_string += f'{field_name} {{ id }}\n'
+                # elif type(field.type) is GQLList and not is_string_array:
+                #     # TODO: Come up with something more elegant
+                #     # Temporary hack to support automated testing of reverse generic relations
+                #     fields_string += f'{field_name} {{ id }}\n'
                 elif inspect.isclass(field.type) and issubclass(field.type, IPAddressFamilyType):
-                    fields_string += f'{field_name} {{ value, label }}\n'
+                    fields_string += f'{field.name} {{ value, label }}\n'
                 else:
-                    fields_string += f'{field_name}\n'
+                    fields_string += f'{field.name}\n'
 
             query = f"""
             {{
@@ -486,6 +477,7 @@ class APIViewTestCases:
             }}
             """
 
+            print(query)
             return query
 
         @override_settings(LOGIN_REQUIRED=True)
@@ -498,6 +490,7 @@ class APIViewTestCases:
 
             # Non-authenticated requests should fail
             with disable_warnings('django.request'):
+                print(f"url: {url}")
                 self.assertHttpStatus(self.client.post(url, data={'query': query}), status.HTTP_403_FORBIDDEN)
 
             # Add object-level permission