Arthur 1 год назад
Родитель
Сommit
14f04453bb

+ 1 - 1
netbox/dcim/graphql/types.py

@@ -541,7 +541,7 @@ class LocationType(VLANGroupsMixin, ImageAttachmentsMixin, ContactsMixin, Organi
         return self.vlan_groups.all()
         return self.vlan_groups.all()
 
 
     @strawberry_django.field
     @strawberry_django.field
-    def parent(self) -> Annotated["LocationType", strawberry.lazy('dcim.graphql.types')]:
+    def parent(self) -> Annotated["LocationType", strawberry.lazy('dcim.graphql.types')] | None:
         return self.parent
         return self.parent
 
 
     @strawberry_django.field
     @strawberry_django.field

+ 6 - 6
netbox/extras/graphql/mixins.py

@@ -24,13 +24,13 @@ if TYPE_CHECKING:
 class ChangelogMixin:
 class ChangelogMixin:
 
 
     @strawberry_django.field
     @strawberry_django.field
-    def changelog(self) -> List[Annotated["ObjectChangeType", strawberry.lazy('.types')]]:
+    def changelog(self, info) -> List[Annotated["ObjectChangeType", strawberry.lazy('.types')]]:
         content_type = ContentType.objects.get_for_model(self)
         content_type = ContentType.objects.get_for_model(self)
         object_changes = ObjectChange.objects.filter(
         object_changes = ObjectChange.objects.filter(
             changed_object_type=content_type,
             changed_object_type=content_type,
             changed_object_id=self.pk
             changed_object_id=self.pk
         )
         )
-        return object_changes.restrict(info.context.user, 'view')
+        return object_changes.restrict(info.context.request.user, 'view')
 
 
 
 
 @strawberry.type
 @strawberry.type
@@ -53,16 +53,16 @@ class CustomFieldsMixin:
 class ImageAttachmentsMixin:
 class ImageAttachmentsMixin:
 
 
     @strawberry_django.field
     @strawberry_django.field
-    def image_attachments(self) -> List[Annotated["ImageAttachmentType", strawberry.lazy('.types')]]:
-        return self.images.restrict(info.context.user, 'view')
+    def image_attachments(self, info) -> List[Annotated["ImageAttachmentType", strawberry.lazy('.types')]]:
+        return self.images.restrict(info.context.request.user, 'view')
 
 
 
 
 @strawberry.type
 @strawberry.type
 class JournalEntriesMixin:
 class JournalEntriesMixin:
 
 
     @strawberry_django.field
     @strawberry_django.field
-    def journal_entries(self) -> List[Annotated["JournalEntryType", strawberry.lazy('.types')]]:
-        return self.journal_entries.restrict(info.context.user, 'view')
+    def journal_entries(self, info) -> List[Annotated["JournalEntryType", strawberry.lazy('.types')]]:
+        return self.journal_entries.restrict(info.context.request.user, 'view')
 
 
 
 
 @strawberry.type
 @strawberry.type

+ 2 - 2
netbox/ipam/graphql/mixins.py

@@ -10,11 +10,11 @@ class IPAddressesMixin:
     ip_addresses = graphene.List('ipam.graphql.types.IPAddressType')
     ip_addresses = graphene.List('ipam.graphql.types.IPAddressType')
 
 
     def resolve_ip_addresses(self, info):
     def resolve_ip_addresses(self, info):
-        return self.ip_addresses.restrict(info.context.user, 'view')
+        return self.ip_addresses.restrict(info.context.request.user, 'view')
 
 
 
 
 class VLANGroupsMixin:
 class VLANGroupsMixin:
     vlan_groups = graphene.List('ipam.graphql.types.VLANGroupType')
     vlan_groups = graphene.List('ipam.graphql.types.VLANGroupType')
 
 
     def resolve_vlan_groups(self, info):
     def resolve_vlan_groups(self, info):
-        return self.vlan_groups.restrict(info.context.user, 'view')
+        return self.vlan_groups.restrict(info.context.request.user, 'view')

+ 5 - 3
netbox/netbox/graphql/views.py

@@ -2,19 +2,21 @@ from django.conf import settings
 from django.contrib.auth.views import redirect_to_login
 from django.contrib.auth.views import redirect_to_login
 from django.http import HttpResponseNotFound, HttpResponseForbidden
 from django.http import HttpResponseNotFound, HttpResponseForbidden
 from django.urls import reverse
 from django.urls import reverse
-from graphene_django.views import GraphQLView as GraphQLView_
+from django.views.decorators.csrf import csrf_exempt
 from rest_framework.exceptions import AuthenticationFailed
 from rest_framework.exceptions import AuthenticationFailed
+from strawberry.django.views import GraphQLView
 
 
 from netbox.api.authentication import TokenAuthentication
 from netbox.api.authentication import TokenAuthentication
 from netbox.config import get_config
 from netbox.config import get_config
 
 
 
 
-class GraphQLView(GraphQLView_):
+class NetBoxGraphQLView(GraphQLView):
     """
     """
-    Extends graphene_django's GraphQLView to support DRF's token-based authentication.
+    Extends strawberry's GraphQLView to support DRF's token-based authentication.
     """
     """
     graphiql_template = 'graphiql.html'
     graphiql_template = 'graphiql.html'
 
 
+    @csrf_exempt
     def dispatch(self, request, *args, **kwargs):
     def dispatch(self, request, *args, **kwargs):
         config = get_config()
         config = get_config()
 
 

+ 2 - 3
netbox/netbox/urls.py

@@ -1,14 +1,13 @@
 from django.conf import settings
 from django.conf import settings
 from django.conf.urls import include
 from django.conf.urls import include
 from django.urls import path
 from django.urls import path
-from django.views.decorators.csrf import csrf_exempt
 from django.views.static import serve
 from django.views.static import serve
 from drf_spectacular.views import SpectacularAPIView, SpectacularRedocView, SpectacularSwaggerView
 from drf_spectacular.views import SpectacularAPIView, SpectacularRedocView, SpectacularSwaggerView
 
 
 from account.views import LoginView, LogoutView
 from account.views import LoginView, LogoutView
 from netbox.api.views import APIRootView, StatusView
 from netbox.api.views import APIRootView, StatusView
 from netbox.graphql.schema import schema
 from netbox.graphql.schema import schema
-from netbox.graphql.views import GraphQLView
+from netbox.graphql.views import NetBoxGraphQLView
 from netbox.plugins.urls import plugin_patterns, plugin_api_patterns
 from netbox.plugins.urls import plugin_patterns, plugin_api_patterns
 from netbox.views import HomeView, StaticMediaFailureView, SearchView, htmx
 from netbox.views import HomeView, StaticMediaFailureView, SearchView, htmx
 from strawberry.django.views import GraphQLView
 from strawberry.django.views import GraphQLView
@@ -61,7 +60,7 @@ _patterns = [
     path('api/schema/redoc/', SpectacularRedocView.as_view(url_name='schema'), name='api_redocs'),
     path('api/schema/redoc/', SpectacularRedocView.as_view(url_name='schema'), name='api_redocs'),
 
 
     # GraphQL
     # GraphQL
-    path('graphql/', GraphQLView.as_view(schema=schema), name='graphql'),
+    path('graphql/', NetBoxGraphQLView.as_view(schema=schema), name='graphql'),
 
 
     # Serving static media in Django to pipe it through LoginRequiredMiddleware
     # Serving static media in Django to pipe it through LoginRequiredMiddleware
     path('media/<path:path>', serve, {'document_root': settings.MEDIA_ROOT}),
     path('media/<path:path>', serve, {'document_root': settings.MEDIA_ROOT}),

+ 18 - 25
netbox/utilities/testing/api.py

@@ -1,5 +1,6 @@
 import inspect
 import inspect
 import json
 import json
+import strawberry_django
 
 
 from django.conf import settings
 from django.conf import settings
 from django.contrib.auth import get_user_model
 from django.contrib.auth import get_user_model
@@ -18,7 +19,7 @@ from .base import ModelTestCase
 from .utils import disable_warnings
 from .utils import disable_warnings
 
 
 from ipam.graphql.types import IPAddressFamilyType
 from ipam.graphql.types import IPAddressFamilyType
-
+from strawberry.type import StrawberryList
 
 
 __all__ = (
 __all__ = (
     'APITestCase',
     'APITestCase',
@@ -447,36 +448,26 @@ class APIViewTestCases:
             # Compile list of fields to include
             # Compile list of fields to include
             fields_string = ''
             fields_string = ''
 
 
-            for field_name, field in type_class.__dataclass_fields__.items():
+            for field in type_class.__strawberry_definition__.fields:
                 # for field_name, field in type_class._meta.fields.items():
                 # for field_name, field in type_class._meta.fields.items():
-                print(f"field_name: {field_name} field: {field}")
-                is_string_array = False
-                if type(field.type) is GQLList:
-                    if field.type.of_type is GQLString:
-                        is_string_array = True
-                    elif type(field.type.of_type) is GQLNonNull and field.type.of_type.of_type is GQLString:
-                        is_string_array = True
-
-                if type(field) is GQLDynamic:
+                print(f"field_name: {field.name} type: {field.type}")
+
+                if type(field.type) is StrawberryList:
+                    fields_string += f'{field.name} {{ id }}\n'
+                elif field.type is strawberry_django.fields.types.DjangoModelType:
                     # Dynamic fields must specify a subselection
                     # Dynamic fields must specify a subselection
-                    fields_string += f'{field_name} {{ id }}\n'
+                    fields_string += f'{field.name} {{ id }}\n'
                 # TODO: Improve field detection logic to avoid nested ArrayFields
                 # TODO: Improve field detection logic to avoid nested ArrayFields
-                elif field_name == 'extra_choices':
-                    continue
-                elif inspect.isclass(field.type) and issubclass(field.type, GQLUnion):
-                    # Union types dont' have an id or consistent values
-                    continue
-                elif type(field.type) is GQLList and inspect.isclass(field.type.of_type) and issubclass(field.type.of_type, GQLUnion):
-                    # Union types dont' have an id or consistent values
+                elif field.name == 'extra_choices':
                     continue
                     continue
-                elif type(field.type) is GQLList and not is_string_array:
-                    # TODO: Come up with something more elegant
-                    # Temporary hack to support automated testing of reverse generic relations
-                    fields_string += f'{field_name} {{ id }}\n'
+                # elif type(field.type) is GQLList and not is_string_array:
+                #     # TODO: Come up with something more elegant
+                #     # Temporary hack to support automated testing of reverse generic relations
+                #     fields_string += f'{field_name} {{ id }}\n'
                 elif inspect.isclass(field.type) and issubclass(field.type, IPAddressFamilyType):
                 elif inspect.isclass(field.type) and issubclass(field.type, IPAddressFamilyType):
-                    fields_string += f'{field_name} {{ value, label }}\n'
+                    fields_string += f'{field.name} {{ value, label }}\n'
                 else:
                 else:
-                    fields_string += f'{field_name}\n'
+                    fields_string += f'{field.name}\n'
 
 
             query = f"""
             query = f"""
             {{
             {{
@@ -486,6 +477,7 @@ class APIViewTestCases:
             }}
             }}
             """
             """
 
 
+            print(query)
             return query
             return query
 
 
         @override_settings(LOGIN_REQUIRED=True)
         @override_settings(LOGIN_REQUIRED=True)
@@ -498,6 +490,7 @@ class APIViewTestCases:
 
 
             # Non-authenticated requests should fail
             # Non-authenticated requests should fail
             with disable_warnings('django.request'):
             with disable_warnings('django.request'):
+                print(f"url: {url}")
                 self.assertHttpStatus(self.client.post(url, data={'query': query}), status.HTTP_403_FORBIDDEN)
                 self.assertHttpStatus(self.client.post(url, data={'query': query}), status.HTTP_403_FORBIDDEN)
 
 
             # Add object-level permission
             # Add object-level permission