Просмотр исходного кода

Add support for DRF token authentication

jeremystretch 4 лет назад
Родитель
Сommit
d5675a5d4a

+ 20 - 0
netbox/netbox/graphql/views.py

@@ -0,0 +1,20 @@
+from graphene_django.views import GraphQLView as GraphQLView_
+from rest_framework.decorators import authentication_classes, permission_classes, api_view
+from rest_framework.permissions import IsAuthenticated
+from rest_framework.settings import api_settings
+
+
+class GraphQLView(GraphQLView_):
+    """
+    Extends grpahene_django's GraphQLView to support DRF's token-based authentication.
+    """
+    @classmethod
+    def as_view(cls, *args, **kwargs):
+        view = super(GraphQLView, cls).as_view(*args, **kwargs)
+
+        # Apply DRF permission and authentication classes
+        view = permission_classes((IsAuthenticated,))(view)
+        view = authentication_classes(api_settings.DEFAULT_AUTHENTICATION_CLASSES)(view)
+        view = api_view(['GET', 'POST'])(view)
+
+        return view

+ 2 - 1
netbox/netbox/middleware.py

@@ -24,7 +24,8 @@ class LoginRequiredMiddleware(object):
         if settings.LOGIN_REQUIRED and not request.user.is_authenticated:
             # Determine exempt paths
             exempt_paths = [
-                reverse('api-root')
+                reverse('api-root'),
+                reverse('graphql'),
             ]
             if settings.METRICS_ENABLED:
                 exempt_paths.append(reverse('prometheus-django-metrics'))

+ 1 - 1
netbox/netbox/urls.py

@@ -4,11 +4,11 @@ from django.urls import path, re_path
 from django.views.static import serve
 from drf_yasg import openapi
 from drf_yasg.views import get_schema_view
-from graphene_django.views import GraphQLView
 
 from extras.plugins.urls import plugin_admin_patterns, plugin_patterns, plugin_api_patterns
 from netbox.api.views import APIRootView, StatusView
 from netbox.graphql.schema import schema
+from netbox.graphql.views import GraphQLView
 from netbox.views import HomeView, StaticMediaFailureView, SearchView
 from users.views import LoginView, LogoutView
 from .admin import admin_site

+ 21 - 0
netbox/utilities/testing/api.py

@@ -425,6 +425,7 @@ class APIViewTestCases:
 
     class GraphQLTestCase(APITestCase):
 
+        @override_settings(LOGIN_REQUIRED=True)
         def test_graphql_get_object(self):
             url = reverse('graphql')
             object_type = self.model._meta.verbose_name.replace(' ', '_')
@@ -441,11 +442,21 @@ class APIViewTestCases:
             with disable_warnings('django.request'):
                 self.assertHttpStatus(self.client.post(url, data={'query': query}), status.HTTP_403_FORBIDDEN)
 
+            # Add object-level permission
+            obj_perm = ObjectPermission(
+                name='Test permission',
+                actions=['view']
+            )
+            obj_perm.save()
+            obj_perm.users.add(self.user)
+            obj_perm.object_types.add(ContentType.objects.get_for_model(self.model))
+
             response = self.client.post(url, data={'query': query}, **self.header)
             self.assertHttpStatus(response, status.HTTP_200_OK)
             data = json.loads(response.content)
             self.assertNotIn('errors', data)
 
+        @override_settings(LOGIN_REQUIRED=True)
         def test_graphql_list_objects(self):
             url = reverse('graphql')
             object_type = self.model._meta.verbose_name_plural.replace(' ', '_')
@@ -461,10 +472,20 @@ class APIViewTestCases:
             with disable_warnings('django.request'):
                 self.assertHttpStatus(self.client.post(url, data={'query': query}), status.HTTP_403_FORBIDDEN)
 
+            # Add object-level permission
+            obj_perm = ObjectPermission(
+                name='Test permission',
+                actions=['view']
+            )
+            obj_perm.save()
+            obj_perm.users.add(self.user)
+            obj_perm.object_types.add(ContentType.objects.get_for_model(self.model))
+
             response = self.client.post(url, data={'query': query}, **self.header)
             self.assertHttpStatus(response, status.HTTP_200_OK)
             data = json.loads(response.content)
             self.assertNotIn('errors', data)
+            self.assertGreater(len(data['data'][object_type]), 0)
 
     class APIViewTestCase(
         GetObjectViewTestCase,