Explorar o código

Merge pull request #21522 from netbox-community/21356-etags

Closes #21356: Implement ETag support for REST API
bctiemann hai 1 día
pai
achega
cd5d88ff8a

+ 82 - 4
netbox/netbox/api/viewsets/__init__.py

@@ -13,7 +13,7 @@ from rest_framework.viewsets import GenericViewSet
 from netbox.api.serializers.features import ChangeLogMessageSerializer
 from netbox.constants import ADVISORY_LOCK_KEYS
 from utilities.api import get_annotations_for_serializer, get_prefetches_for_serializer
-from utilities.exceptions import AbortRequest
+from utilities.exceptions import AbortRequest, PreconditionFailed
 from utilities.query import reapply_model_ordering
 
 from . import mixins
@@ -34,6 +34,50 @@ HTTP_ACTIONS = {
 }
 
 
+class ETagMixin:
+    """
+    Adds ETag header support to ViewSets. Generates weak ETags (W/ prefix per
+    RFC 7232 §2.1) from `last_updated` (or `created` if unavailable). Weak ETags
+    are appropriate here because the tag is derived from a modification timestamp
+    rather than a hash of the serialized payload.
+    """
+
+    @staticmethod
+    def _get_etag(obj):
+        """Return a weak ETag string for the given object, or None."""
+        if ts := getattr(obj, 'last_updated', None) or getattr(obj, 'created', None):
+            return f'W/"{ts.isoformat()}"'
+        return None
+
+    @staticmethod
+    def _get_if_match(request):
+        """Return the list of If-Match header values (if specified)."""
+        if (if_match := request.META.get('HTTP_IF_MATCH')) and if_match != '*':
+            return [e.strip() for e in if_match.split(',')]
+        return []
+
+    def _validate_etag(self, request, instance):
+        """Validate the request's ETag"""
+        if provided := self._get_if_match(request):
+            current_etag = self._get_etag(instance)
+            if current_etag and current_etag not in provided:
+                raise PreconditionFailed(etag=current_etag)
+
+    def handle_exception(self, exc):
+        response = super().handle_exception(exc)
+        if isinstance(exc, PreconditionFailed) and exc.etag:
+            response['ETag'] = exc.etag
+        return response
+
+    def retrieve(self, request, *args, **kwargs):
+        instance = self.get_object()
+        serializer = self.get_serializer(instance)
+        response = Response(serializer.data)
+        if etag := self._get_etag(instance):
+            response['ETag'] = etag
+        return response
+
+
 class BaseViewSet(GenericViewSet):
     """
     Base class for all API ViewSets. This is responsible for the enforcement of object-based permissions.
@@ -95,6 +139,7 @@ class BaseViewSet(GenericViewSet):
 
 
 class NetBoxReadOnlyModelViewSet(
+    ETagMixin,
     mixins.CustomFieldsMixin,
     mixins.ExportTemplatesMixin,
     drf_mixins.RetrieveModelMixin,
@@ -105,6 +150,7 @@ class NetBoxReadOnlyModelViewSet(
 
 
 class NetBoxModelViewSet(
+    ETagMixin,
     mixins.BulkUpdateModelMixin,
     mixins.BulkDestroyModelMixin,
     mixins.ObjectValidationMixin,
@@ -191,7 +237,14 @@ class NetBoxModelViewSet(
         serializer = self.get_serializer(qs, many=bulk_create)
 
         headers = self.get_success_headers(serializer.data)
-        return Response(serializer.data, status=status.HTTP_201_CREATED, headers=headers)
+        response = Response(serializer.data, status=status.HTTP_201_CREATED, headers=headers)
+
+        # Add ETag for single-object creation only (bulk returns a list, no single ETag)
+        if not bulk_create:
+            if etag := self._get_etag(qs):
+                response['ETag'] = etag
+
+        return response
 
     def perform_create(self, serializer):
         model = self.queryset.model
@@ -211,6 +264,10 @@ class NetBoxModelViewSet(
     def update(self, request, *args, **kwargs):
         partial = kwargs.pop('partial', False)
         instance = self.get_object_with_snapshot()
+
+        # Enforce If-Match precondition (RFC 9110 §13.1.1)
+        self._validate_etag(self.request, instance)
+
         serializer = self.get_serializer(instance, data=request.data, partial=partial)
         serializer.is_valid(raise_exception=True)
         self.perform_update(serializer)
@@ -221,8 +278,12 @@ class NetBoxModelViewSet(
 
         # Re-serialize the instance(s) with prefetched data
         serializer = self.get_serializer(qs)
+        response = Response(serializer.data)
+
+        if etag := self._get_etag(qs):
+            response['ETag'] = etag
 
-        return Response(serializer.data)
+        return response
 
     def perform_update(self, serializer):
         model = self.queryset.model
@@ -232,6 +293,11 @@ class NetBoxModelViewSet(
         # Enforce object-level permissions on save()
         try:
             with transaction.atomic(using=router.db_for_write(model)):
+                # Re-check the If-Match ETag under a row-level lock to close the TOCTOU window
+                # between the initial check in update() and the actual write.
+                if self._get_if_match(self.request):
+                    locked = model.objects.select_for_update().get(pk=serializer.instance.pk)
+                    self._validate_etag(self.request, locked)
                 instance = serializer.save()
                 self._validate_objects(instance)
         except ObjectDoesNotExist:
@@ -242,6 +308,9 @@ class NetBoxModelViewSet(
     def destroy(self, request, *args, **kwargs):
         instance = self.get_object_with_snapshot()
 
+        # Enforce If-Match precondition (RFC 9110 §13.1.1)
+        self._validate_etag(request, instance)
+
         # Attach changelog message (if any)
         serializer = ChangeLogMessageSerializer(data=request.data)
         serializer.is_valid(raise_exception=True)
@@ -256,7 +325,16 @@ class NetBoxModelViewSet(
         logger = logging.getLogger(f'netbox.api.views.{self.__class__.__name__}')
         logger.info(f"Deleting {model._meta.verbose_name} {instance} (PK: {instance.pk})")
 
-        return super().perform_destroy(instance)
+        try:
+            with transaction.atomic(using=router.db_for_write(model)):
+                # Re-check the If-Match ETag under a row-level lock to close the TOCTOU window
+                # between the initial check in destroy() and the actual delete.
+                if self._get_if_match(self.request):
+                    locked = model.objects.select_for_update().get(pk=instance.pk)
+                    self._validate_etag(self.request, locked)
+                super().perform_destroy(instance)
+        except ObjectDoesNotExist:
+            raise PermissionDenied()
 
 
 class MPTTLockedMixin:

+ 15 - 0
netbox/utilities/exceptions.py

@@ -6,6 +6,7 @@ __all__ = (
     'AbortScript',
     'AbortTransaction',
     'PermissionsViolation',
+    'PreconditionFailed',
     'RQWorkerNotRunningException',
 )
 
@@ -40,6 +41,20 @@ class PermissionsViolation(Exception):
     message = "Operation failed due to object-level permissions violation"
 
 
+class PreconditionFailed(APIException):
+    """
+    Raised when an If-Match precondition is not satisfied (HTTP 412).
+    Optionally carries the current ETag so it can be included in the response.
+    """
+    status_code = status.HTTP_412_PRECONDITION_FAILED
+    default_detail = 'Precondition failed.'
+    default_code = 'precondition_failed'
+
+    def __init__(self, detail=None, code=None, etag=None):
+        super().__init__(detail=detail, code=code)
+        self.etag = etag
+
+
 class RQWorkerNotRunningException(APIException):
     """
     Indicates the temporary inability to enqueue a new task (e.g. custom script execution) because no RQ worker

+ 46 - 1
netbox/utilities/testing/api.py

@@ -114,7 +114,12 @@ class APIViewTestCases:
 
             # Try GET to permitted object
             url = self._get_detail_url(instance1)
-            self.assertHttpStatus(self.client.get(url, **self.header), status.HTTP_200_OK)
+            response = self.client.get(url, **self.header)
+            self.assertHttpStatus(response, status.HTTP_200_OK)
+
+            # Verify ETag header is present for objects with timestamps
+            if issubclass(self.model, ChangeLoggingMixin):
+                self.assertIn('ETag', response, "ETag header missing from detail response")
 
             # Try GET to non-permitted object
             url = self._get_detail_url(instance2)
@@ -367,6 +372,46 @@ class APIViewTestCases:
                 self.assertEqual(objectchange.action, ObjectChangeActionChoices.ACTION_UPDATE)
                 self.assertEqual(objectchange.message, data['changelog_message'])
 
+        def test_update_object_with_etag(self):
+            """
+            PATCH an object using a valid If-Match ETag → expect 200.
+            PATCH again with the now-stale ETag → expect 412.
+            """
+            if not issubclass(self.model, ChangeLoggingMixin):
+                self.skipTest("Model does not support ETags")
+
+            self.add_permissions(
+                f'{self.model._meta.app_label}.view_{self.model._meta.model_name}',
+                f'{self.model._meta.app_label}.change_{self.model._meta.model_name}',
+            )
+            instance = self._get_queryset().first()
+            url = self._get_detail_url(instance)
+            update_data = self.update_data or getattr(self, 'create_data')[0]
+
+            # Fetch current ETag
+            get_response = self.client.get(url, **self.header)
+            self.assertHttpStatus(get_response, status.HTTP_200_OK)
+            etag = get_response.get('ETag')
+            self.assertIsNotNone(etag, "No ETag returned by GET")
+
+            # PATCH with correct ETag → 200
+            response = self.client.patch(
+                url, update_data, format='json',
+                **{**self.header, 'HTTP_IF_MATCH': etag}
+            )
+            self.assertHttpStatus(response, status.HTTP_200_OK)
+            new_etag = response.get('ETag')
+            self.assertIsNotNone(new_etag)
+            self.assertNotEqual(etag, new_etag)  # ETag must change after update
+
+            # PATCH with the old (stale) ETag → 412
+            with disable_warnings('django.request'):
+                response = self.client.patch(
+                    url, update_data, format='json',
+                    **{**self.header, 'HTTP_IF_MATCH': etag}
+                )
+            self.assertHttpStatus(response, status.HTTP_412_PRECONDITION_FAILED)
+
         def test_bulk_update_objects(self):
             """
             PATCH a set of objects in a single request.