Przeglądaj źródła

19644 Make atomic use correct database instead of default (#19651)

* 19644 set atomic transactions to appropriate database

* 19644 set atomic transactions for Job Script run

* 19644 set atomic transactions to appropriate database

* 19644 set atomic transactions to appropriate database

* 19644 fix review comments

* 19644 fix review comments
Arthur Hanson 7 miesięcy temu
rodzic
commit
a17699d261

+ 2 - 2
netbox/circuits/views.py

@@ -1,5 +1,5 @@
 from django.contrib import messages
-from django.db import transaction
+from django.db import router, transaction
 from django.shortcuts import get_object_or_404, redirect, render
 from django.utils.translation import gettext_lazy as _
 
@@ -384,7 +384,7 @@ class CircuitSwapTerminations(generic.ObjectEditView):
 
             if termination_a and termination_z:
                 # Use a placeholder to avoid an IntegrityError on the (circuit, term_side) unique constraint
-                with transaction.atomic():
+                with transaction.atomic(using=router.db_for_write(CircuitTermination)):
                     termination_a.term_side = '_'
                     termination_a.save()
                     termination_z.term_side = 'A'

+ 2 - 2
netbox/dcim/utils.py

@@ -1,6 +1,6 @@
 from django.apps import apps
 from django.contrib.contenttypes.models import ContentType
-from django.db import transaction
+from django.db import router, transaction
 
 
 def compile_path_node(ct_id, object_id):
@@ -53,7 +53,7 @@ def rebuild_paths(terminations):
     for obj in terminations:
         cable_paths = CablePath.objects.filter(_nodes__contains=obj)
 
-        with transaction.atomic():
+        with transaction.atomic(using=router.db_for_write(CablePath)):
             for cp in cable_paths:
                 cp.delete()
                 create_cablepath(cp.origins)

+ 3 - 3
netbox/dcim/views.py

@@ -1,7 +1,7 @@
 from django.contrib import messages
 from django.contrib.contenttypes.models import ContentType
 from django.core.paginator import EmptyPage, PageNotAnInteger
-from django.db import transaction
+from django.db import router, transaction
 from django.db.models import Prefetch
 from django.forms import ModelMultipleChoiceField, MultipleHiddenInput, modelformset_factory
 from django.shortcuts import get_object_or_404, redirect, render
@@ -124,7 +124,7 @@ class BulkDisconnectView(GetReturnURLMixin, ObjectPermissionRequiredMixin, View)
 
             if form.is_valid():
 
-                with transaction.atomic():
+                with transaction.atomic(using=router.db_for_write(Cable)):
                     count = 0
                     cable_ids = set()
                     for obj in self.queryset.filter(pk__in=form.cleaned_data['pk']):
@@ -3746,7 +3746,7 @@ class VirtualChassisEditView(ObjectPermissionRequiredMixin, GetReturnURLMixin, V
 
         if vc_form.is_valid() and formset.is_valid():
 
-            with transaction.atomic():
+            with transaction.atomic(using=router.db_for_write(Device)):
 
                 # Save the VirtualChassis
                 vc_form.save()

+ 3 - 0
netbox/extras/jobs.py

@@ -39,6 +39,9 @@ class ScriptJob(JobRunner):
 
         try:
             try:
+                # A script can modify multiple models so need to do an atomic lock on
+                # both the default database (for non ChangeLogged models) and potentially
+                # any other database (for ChangeLogged models)
                 with transaction.atomic():
                     script.output = script.run(data, commit)
                     if not commit:

+ 2 - 2
netbox/ipam/api/views.py

@@ -2,7 +2,7 @@ from copy import deepcopy
 
 from django.contrib.contenttypes.prefetch import GenericPrefetch
 from django.core.exceptions import ObjectDoesNotExist, PermissionDenied
-from django.db import transaction
+from django.db import router, transaction
 from django.shortcuts import get_object_or_404
 from django.utils.translation import gettext as _
 from django_pglocks import advisory_lock
@@ -295,7 +295,7 @@ class AvailableObjectsView(ObjectValidationMixin, APIView):
 
             # Create the new IP address(es)
             try:
-                with transaction.atomic():
+                with transaction.atomic(using=router.db_for_write(self.queryset.model)):
                     created = serializer.save()
                     self._validate_objects(created)
             except ObjectDoesNotExist:

+ 3 - 3
netbox/netbox/api/viewsets/__init__.py

@@ -2,7 +2,7 @@ import logging
 from functools import cached_property
 
 from django.core.exceptions import ObjectDoesNotExist, PermissionDenied
-from django.db import transaction
+from django.db import router, transaction
 from django.db.models import ProtectedError, RestrictedError
 from django_pglocks import advisory_lock
 from netbox.constants import ADVISORY_LOCK_KEYS
@@ -170,7 +170,7 @@ class NetBoxModelViewSet(
 
         # Enforce object-level permissions on save()
         try:
-            with transaction.atomic():
+            with transaction.atomic(using=router.db_for_write(model)):
                 instance = serializer.save()
                 self._validate_objects(instance)
         except ObjectDoesNotExist:
@@ -190,7 +190,7 @@ class NetBoxModelViewSet(
 
         # Enforce object-level permissions on save()
         try:
-            with transaction.atomic():
+            with transaction.atomic(using=router.db_for_write(model)):
                 instance = serializer.save()
                 self._validate_objects(instance)
         except ObjectDoesNotExist:

+ 16 - 16
netbox/netbox/api/viewsets/mixins.py

@@ -1,5 +1,5 @@
 from django.core.exceptions import ObjectDoesNotExist
-from django.db import transaction
+from django.db import router, transaction
 from django.http import Http404
 from rest_framework import status
 from rest_framework.response import Response
@@ -56,22 +56,22 @@ class SequentialBulkCreatesMixin:
     which depends on the evaluation of existing objects (such as checking for free space within a rack) functions
     appropriately.
     """
-    @transaction.atomic
     def create(self, request, *args, **kwargs):
-        if not isinstance(request.data, list):
-            # Creating a single object
-            return super().create(request, *args, **kwargs)
-
-        return_data = []
-        for data in request.data:
-            serializer = self.get_serializer(data=data)
-            serializer.is_valid(raise_exception=True)
-            self.perform_create(serializer)
-            return_data.append(serializer.data)
+        with transaction.atomic(using=router.db_for_write(self.queryset.model)):
+            if not isinstance(request.data, list):
+                # Creating a single object
+                return super().create(request, *args, **kwargs)
+
+            return_data = []
+            for data in request.data:
+                serializer = self.get_serializer(data=data)
+                serializer.is_valid(raise_exception=True)
+                self.perform_create(serializer)
+                return_data.append(serializer.data)
 
-        headers = self.get_success_headers(serializer.data)
+            headers = self.get_success_headers(serializer.data)
 
-        return Response(return_data, status=status.HTTP_201_CREATED, headers=headers)
+            return Response(return_data, status=status.HTTP_201_CREATED, headers=headers)
 
 
 class BulkUpdateModelMixin:
@@ -113,7 +113,7 @@ class BulkUpdateModelMixin:
         return Response(data, status=status.HTTP_200_OK)
 
     def perform_bulk_update(self, objects, update_data, partial):
-        with transaction.atomic():
+        with transaction.atomic(using=router.db_for_write(self.queryset.model)):
             data_list = []
             for obj in objects:
                 data = update_data.get(obj.id)
@@ -157,7 +157,7 @@ class BulkDestroyModelMixin:
         return Response(status=status.HTTP_204_NO_CONTENT)
 
     def perform_bulk_destroy(self, objects):
-        with transaction.atomic():
+        with transaction.atomic(using=router.db_for_write(self.queryset.model)):
             for obj in objects:
                 if hasattr(obj, 'snapshot'):
                     obj.snapshot()

+ 7 - 7
netbox/netbox/views/generic/bulk_views.py

@@ -6,7 +6,7 @@ from django.contrib import messages
 from django.contrib.contenttypes.fields import GenericForeignKey, GenericRel
 from django.contrib.contenttypes.models import ContentType
 from django.core.exceptions import FieldDoesNotExist, ObjectDoesNotExist, ValidationError
-from django.db import transaction, IntegrityError
+from django.db import IntegrityError, router, transaction
 from django.db.models import ManyToManyField, ProtectedError, RestrictedError
 from django.db.models.fields.reverse_related import ManyToManyRel
 from django.forms import ModelMultipleChoiceField, MultipleHiddenInput
@@ -278,7 +278,7 @@ class BulkCreateView(GetReturnURLMixin, BaseMultiObjectView):
             logger.debug("Form validation was successful")
 
             try:
-                with transaction.atomic():
+                with transaction.atomic(using=router.db_for_write(model)):
                     new_objs = self._create_objects(form, request)
 
                     # Enforce object-level permissions
@@ -501,7 +501,7 @@ class BulkImportView(GetReturnURLMixin, BaseMultiObjectView):
 
             try:
                 # Iterate through data and bind each record to a new model form instance.
-                with transaction.atomic():
+                with transaction.atomic(using=router.db_for_write(model)):
                     new_objs = self.create_and_update_objects(form, request)
 
                     # Enforce object-level permissions
@@ -681,7 +681,7 @@ class BulkEditView(GetReturnURLMixin, BaseMultiObjectView):
             if form.is_valid():
                 logger.debug("Form validation was successful")
                 try:
-                    with transaction.atomic():
+                    with transaction.atomic(using=router.db_for_write(model)):
                         updated_objects = self._update_objects(form, request)
 
                         # Enforce object-level permissions
@@ -778,7 +778,7 @@ class BulkRenameView(GetReturnURLMixin, BaseMultiObjectView):
 
             if form.is_valid():
                 try:
-                    with transaction.atomic():
+                    with transaction.atomic(using=router.db_for_write(self.queryset.model)):
                         renamed_pks = self._rename_objects(form, selected_objects)
 
                         if '_apply' in request.POST:
@@ -875,7 +875,7 @@ class BulkDeleteView(GetReturnURLMixin, BaseMultiObjectView):
                 queryset = self.queryset.filter(pk__in=pk_list)
                 deleted_count = queryset.count()
                 try:
-                    with transaction.atomic():
+                    with transaction.atomic(using=router.db_for_write(model)):
                         for obj in queryset:
                             # Take a snapshot of change-logged models
                             if hasattr(obj, 'snapshot'):
@@ -980,7 +980,7 @@ class BulkComponentCreateView(GetReturnURLMixin, BaseMultiObjectView):
                 }
 
                 try:
-                    with transaction.atomic():
+                    with transaction.atomic(using=router.db_for_write(self.queryset.model)):
 
                         for obj in data['pk']:
 

+ 2 - 2
netbox/netbox/views/generic/feature_views.py

@@ -1,7 +1,7 @@
 from django.contrib.auth.mixins import LoginRequiredMixin
 from django.contrib.contenttypes.models import ContentType
 from django.contrib import messages
-from django.db import transaction
+from django.db import router, transaction
 from django.db.models import Q
 from django.shortcuts import get_object_or_404, redirect, render
 from django.utils.translation import gettext_lazy as _
@@ -240,7 +240,7 @@ class BulkSyncDataView(GetReturnURLMixin, BaseMultiObjectView):
             data_file__isnull=False
         )
 
-        with transaction.atomic():
+        with transaction.atomic(using=router.db_for_write(self.queryset.model)):
             for obj in selected_objects:
                 obj.sync(save=True)
 

+ 2 - 2
netbox/netbox/views/generic/object_views.py

@@ -282,7 +282,7 @@ class ObjectEditView(GetReturnURLMixin, BaseObjectView):
             logger.debug("Form validation was successful")
 
             try:
-                with transaction.atomic():
+                with transaction.atomic(using=router.db_for_write(model)):
                     object_created = form.instance.pk is None
                     obj = form.save()
 
@@ -570,7 +570,7 @@ class ComponentCreateView(GetReturnURLMixin, BaseObjectView):
 
             if not form.errors and not component_form.errors:
                 try:
-                    with transaction.atomic():
+                    with transaction.atomic(using=router.db_for_write(self.queryset.model)):
                         # Create the new components
                         new_objs = []
                         for component_form in new_components:

+ 3 - 3
netbox/virtualization/views.py

@@ -1,6 +1,6 @@
 from django.contrib import messages
 from django.contrib.contenttypes.models import ContentType
-from django.db import transaction
+from django.db import router, transaction
 from django.db.models import Prefetch, Sum
 from django.shortcuts import get_object_or_404, redirect, render
 from django.urls import reverse
@@ -297,7 +297,7 @@ class ClusterAddDevicesView(generic.ObjectEditView):
         if form.is_valid():
 
             device_pks = form.cleaned_data['devices']
-            with transaction.atomic():
+            with transaction.atomic(using=router.db_for_write(Device)):
 
                 # Assign the selected Devices to the Cluster
                 for device in Device.objects.filter(pk__in=device_pks):
@@ -332,7 +332,7 @@ class ClusterRemoveDevicesView(generic.ObjectEditView):
             if form.is_valid():
 
                 device_pks = form.cleaned_data['pk']
-                with transaction.atomic():
+                with transaction.atomic(using=router.db_for_write(Device)):
 
                     # Remove the selected Devices from the Cluster
                     for device in Device.objects.filter(pk__in=device_pks):