__init__.py 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183
  1. import logging
  2. from django.contrib.contenttypes.models import ContentType
  3. from django.core.exceptions import ObjectDoesNotExist, PermissionDenied
  4. from django.db import transaction
  5. from django.db.models import ProtectedError
  6. from django.shortcuts import get_object_or_404
  7. from rest_framework.response import Response
  8. from rest_framework.viewsets import ModelViewSet
  9. from extras.models import ExportTemplate
  10. from netbox.api.exceptions import SerializerNotFound
  11. from netbox.constants import NESTED_SERIALIZER_PREFIX
  12. from utilities.api import get_serializer_for_model
  13. from .mixins import *
  14. __all__ = (
  15. 'NetBoxModelViewSet',
  16. )
  17. HTTP_ACTIONS = {
  18. 'GET': 'view',
  19. 'OPTIONS': None,
  20. 'HEAD': 'view',
  21. 'POST': 'add',
  22. 'PUT': 'change',
  23. 'PATCH': 'change',
  24. 'DELETE': 'delete',
  25. }
  26. class NetBoxModelViewSet(BulkUpdateModelMixin, BulkDestroyModelMixin, ObjectValidationMixin, ModelViewSet):
  27. """
  28. Extend DRF's ModelViewSet to support bulk update and delete functions.
  29. """
  30. brief = False
  31. brief_prefetch_fields = []
  32. def get_object_with_snapshot(self):
  33. """
  34. Save a pre-change snapshot of the object immediately after retrieving it. This snapshot will be used to
  35. record the "before" data in the changelog.
  36. """
  37. obj = super().get_object()
  38. if hasattr(obj, 'snapshot'):
  39. obj.snapshot()
  40. return obj
  41. def get_serializer(self, *args, **kwargs):
  42. # If a list of objects has been provided, initialize the serializer with many=True
  43. if isinstance(kwargs.get('data', {}), list):
  44. kwargs['many'] = True
  45. return super().get_serializer(*args, **kwargs)
  46. def get_serializer_class(self):
  47. logger = logging.getLogger('netbox.api.views.ModelViewSet')
  48. # If using 'brief' mode, find and return the nested serializer for this model, if one exists
  49. if self.brief:
  50. logger.debug("Request is for 'brief' format; initializing nested serializer")
  51. try:
  52. serializer = get_serializer_for_model(self.queryset.model, prefix=NESTED_SERIALIZER_PREFIX)
  53. logger.debug(f"Using serializer {serializer}")
  54. return serializer
  55. except SerializerNotFound:
  56. logger.debug(f"Nested serializer for {self.queryset.model} not found!")
  57. # Fall back to the hard-coded serializer class
  58. logger.debug(f"Using serializer {self.serializer_class}")
  59. return self.serializer_class
  60. def get_serializer_context(self):
  61. """
  62. For models which support custom fields, populate the `custom_fields` context.
  63. """
  64. context = super().get_serializer_context()
  65. if hasattr(self.queryset.model, 'custom_fields'):
  66. content_type = ContentType.objects.get_for_model(self.queryset.model)
  67. context.update({
  68. 'custom_fields': content_type.custom_fields.all(),
  69. })
  70. return context
  71. def get_queryset(self):
  72. # If using brief mode, clear all prefetches from the queryset and append only brief_prefetch_fields (if any)
  73. if self.brief:
  74. return super().get_queryset().prefetch_related(None).prefetch_related(*self.brief_prefetch_fields)
  75. return super().get_queryset()
  76. def initialize_request(self, request, *args, **kwargs):
  77. # Check if brief=True has been passed
  78. if request.method == 'GET' and request.GET.get('brief'):
  79. self.brief = True
  80. return super().initialize_request(request, *args, **kwargs)
  81. def initial(self, request, *args, **kwargs):
  82. super().initial(request, *args, **kwargs)
  83. if not request.user.is_authenticated:
  84. return
  85. # Restrict the view's QuerySet to allow only the permitted objects
  86. action = HTTP_ACTIONS[request.method]
  87. if action:
  88. self.queryset = self.queryset.restrict(request.user, action)
  89. def dispatch(self, request, *args, **kwargs):
  90. logger = logging.getLogger('netbox.api.views.ModelViewSet')
  91. try:
  92. return super().dispatch(request, *args, **kwargs)
  93. except ProtectedError as e:
  94. protected_objects = list(e.protected_objects)
  95. msg = f'Unable to delete object. {len(protected_objects)} dependent objects were found: '
  96. msg += ', '.join([f'{obj} ({obj.pk})' for obj in protected_objects])
  97. logger.warning(msg)
  98. return self.finalize_response(
  99. request,
  100. Response({'detail': msg}, status=409),
  101. *args,
  102. **kwargs
  103. )
  104. def list(self, request, *args, **kwargs):
  105. """
  106. Overrides ListModelMixin to allow processing ExportTemplates.
  107. """
  108. if 'export' in request.GET:
  109. content_type = ContentType.objects.get_for_model(self.get_serializer_class().Meta.model)
  110. et = get_object_or_404(ExportTemplate, content_type=content_type, name=request.GET['export'])
  111. queryset = self.filter_queryset(self.get_queryset())
  112. return et.render_to_response(queryset)
  113. return super().list(request, *args, **kwargs)
  114. def perform_create(self, serializer):
  115. model = self.queryset.model
  116. logger = logging.getLogger('netbox.api.views.ModelViewSet')
  117. logger.info(f"Creating new {model._meta.verbose_name}")
  118. # Enforce object-level permissions on save()
  119. try:
  120. with transaction.atomic():
  121. instance = serializer.save()
  122. self._validate_objects(instance)
  123. except ObjectDoesNotExist:
  124. raise PermissionDenied()
  125. def update(self, request, *args, **kwargs):
  126. # Hotwire get_object() to ensure we save a pre-change snapshot
  127. self.get_object = self.get_object_with_snapshot
  128. return super().update(request, *args, **kwargs)
  129. def perform_update(self, serializer):
  130. model = self.queryset.model
  131. logger = logging.getLogger('netbox.api.views.ModelViewSet')
  132. logger.info(f"Updating {model._meta.verbose_name} {serializer.instance} (PK: {serializer.instance.pk})")
  133. # Enforce object-level permissions on save()
  134. try:
  135. with transaction.atomic():
  136. instance = serializer.save()
  137. self._validate_objects(instance)
  138. except ObjectDoesNotExist:
  139. raise PermissionDenied()
  140. def destroy(self, request, *args, **kwargs):
  141. # Hotwire get_object() to ensure we save a pre-change snapshot
  142. self.get_object = self.get_object_with_snapshot
  143. return super().destroy(request, *args, **kwargs)
  144. def perform_destroy(self, instance):
  145. model = self.queryset.model
  146. logger = logging.getLogger('netbox.api.views.ModelViewSet')
  147. logger.info(f"Deleting {model._meta.verbose_name} {instance} (PK: {instance.pk})")
  148. return super().perform_destroy(instance)