Просмотр исходного кода

Enforce object-level permissions for API views

Jeremy Stretch 5 лет назад
Родитель
Сommit
aeb32104a4
1 измененных файлов с 18 добавлено и 2 удалено
  1. 18 2
      netbox/utilities/api.py

+ 18 - 2
netbox/utilities/api.py

@@ -6,15 +6,15 @@ from django.conf import settings
 from django.contrib.contenttypes.models import ContentType
 from django.core.exceptions import FieldError, MultipleObjectsReturned, ObjectDoesNotExist
 from django.db.models import ManyToManyField, ProtectedError
-from django.http import Http404
 from django.urls import reverse
 from rest_framework.exceptions import APIException
 from rest_framework.permissions import BasePermission
 from rest_framework.relations import PrimaryKeyRelatedField, RelatedField
 from rest_framework.response import Response
 from rest_framework.serializers import Field, ModelSerializer, ValidationError
-from rest_framework.viewsets import ModelViewSet as _ModelViewSet, ViewSet
+from rest_framework.viewsets import ModelViewSet as _ModelViewSet
 
+from users.models import ObjectPermission
 from .utils import dict_to_filter_params, dynamic_import
 
 
@@ -323,6 +323,22 @@ class ModelViewSet(_ModelViewSet):
         logger.debug(f"Using serializer {self.serializer_class}")
         return self.serializer_class
 
+    def initial(self, request, *args, **kwargs):
+        super().initial(request, *args, **kwargs)
+
+        if not request.user.is_authenticated or request.user.is_superuser:
+            return
+
+        permission_required = 'dcim.view_site'
+
+        # Enforce object-level permissions
+        if permission_required not in self.request.user._perm_cache:
+            attrs = ObjectPermission.objects.get_attr_constraints(self.request.user, permission_required)
+            if attrs:
+                # Update the view's QuerySet to filter only the permitted objects
+                self.queryset = self.queryset.filter(attrs)
+                return True
+
     def dispatch(self, request, *args, **kwargs):
         logger = logging.getLogger('netbox.api.views.ModelViewSet')