瀏覽代碼

19644 Make atomic use correct database instead of default (#19651)

* 19644 set atomic transactions to appropriate database

* 19644 set atomic transactions for Job Script run

* 19644 set atomic transactions to appropriate database

* 19644 set atomic transactions to appropriate database

* 19644 fix review comments

* 19644 fix review comments
Arthur Hanson 7 月之前
父節點
當前提交
a17699d261

+ 2 - 2
netbox/circuits/views.py

@@ -1,5 +1,5 @@
 from django.contrib import messages
 from django.contrib import messages
-from django.db import transaction
+from django.db import router, transaction
 from django.shortcuts import get_object_or_404, redirect, render
 from django.shortcuts import get_object_or_404, redirect, render
 from django.utils.translation import gettext_lazy as _
 from django.utils.translation import gettext_lazy as _
 
 
@@ -384,7 +384,7 @@ class CircuitSwapTerminations(generic.ObjectEditView):
 
 
             if termination_a and termination_z:
             if termination_a and termination_z:
                 # Use a placeholder to avoid an IntegrityError on the (circuit, term_side) unique constraint
                 # Use a placeholder to avoid an IntegrityError on the (circuit, term_side) unique constraint
-                with transaction.atomic():
+                with transaction.atomic(using=router.db_for_write(CircuitTermination)):
                     termination_a.term_side = '_'
                     termination_a.term_side = '_'
                     termination_a.save()
                     termination_a.save()
                     termination_z.term_side = 'A'
                     termination_z.term_side = 'A'

+ 2 - 2
netbox/dcim/utils.py

@@ -1,6 +1,6 @@
 from django.apps import apps
 from django.apps import apps
 from django.contrib.contenttypes.models import ContentType
 from django.contrib.contenttypes.models import ContentType
-from django.db import transaction
+from django.db import router, transaction
 
 
 
 
 def compile_path_node(ct_id, object_id):
 def compile_path_node(ct_id, object_id):
@@ -53,7 +53,7 @@ def rebuild_paths(terminations):
     for obj in terminations:
     for obj in terminations:
         cable_paths = CablePath.objects.filter(_nodes__contains=obj)
         cable_paths = CablePath.objects.filter(_nodes__contains=obj)
 
 
-        with transaction.atomic():
+        with transaction.atomic(using=router.db_for_write(CablePath)):
             for cp in cable_paths:
             for cp in cable_paths:
                 cp.delete()
                 cp.delete()
                 create_cablepath(cp.origins)
                 create_cablepath(cp.origins)

+ 3 - 3
netbox/dcim/views.py

@@ -1,7 +1,7 @@
 from django.contrib import messages
 from django.contrib import messages
 from django.contrib.contenttypes.models import ContentType
 from django.contrib.contenttypes.models import ContentType
 from django.core.paginator import EmptyPage, PageNotAnInteger
 from django.core.paginator import EmptyPage, PageNotAnInteger
-from django.db import transaction
+from django.db import router, transaction
 from django.db.models import Prefetch
 from django.db.models import Prefetch
 from django.forms import ModelMultipleChoiceField, MultipleHiddenInput, modelformset_factory
 from django.forms import ModelMultipleChoiceField, MultipleHiddenInput, modelformset_factory
 from django.shortcuts import get_object_or_404, redirect, render
 from django.shortcuts import get_object_or_404, redirect, render
@@ -124,7 +124,7 @@ class BulkDisconnectView(GetReturnURLMixin, ObjectPermissionRequiredMixin, View)
 
 
             if form.is_valid():
             if form.is_valid():
 
 
-                with transaction.atomic():
+                with transaction.atomic(using=router.db_for_write(Cable)):
                     count = 0
                     count = 0
                     cable_ids = set()
                     cable_ids = set()
                     for obj in self.queryset.filter(pk__in=form.cleaned_data['pk']):
                     for obj in self.queryset.filter(pk__in=form.cleaned_data['pk']):
@@ -3746,7 +3746,7 @@ class VirtualChassisEditView(ObjectPermissionRequiredMixin, GetReturnURLMixin, V
 
 
         if vc_form.is_valid() and formset.is_valid():
         if vc_form.is_valid() and formset.is_valid():
 
 
-            with transaction.atomic():
+            with transaction.atomic(using=router.db_for_write(Device)):
 
 
                 # Save the VirtualChassis
                 # Save the VirtualChassis
                 vc_form.save()
                 vc_form.save()

+ 3 - 0
netbox/extras/jobs.py

@@ -39,6 +39,9 @@ class ScriptJob(JobRunner):
 
 
         try:
         try:
             try:
             try:
+                # A script can modify multiple models so need to do an atomic lock on
+                # both the default database (for non ChangeLogged models) and potentially
+                # any other database (for ChangeLogged models)
                 with transaction.atomic():
                 with transaction.atomic():
                     script.output = script.run(data, commit)
                     script.output = script.run(data, commit)
                     if not commit:
                     if not commit:

+ 2 - 2
netbox/ipam/api/views.py

@@ -2,7 +2,7 @@ from copy import deepcopy
 
 
 from django.contrib.contenttypes.prefetch import GenericPrefetch
 from django.contrib.contenttypes.prefetch import GenericPrefetch
 from django.core.exceptions import ObjectDoesNotExist, PermissionDenied
 from django.core.exceptions import ObjectDoesNotExist, PermissionDenied
-from django.db import transaction
+from django.db import router, transaction
 from django.shortcuts import get_object_or_404
 from django.shortcuts import get_object_or_404
 from django.utils.translation import gettext as _
 from django.utils.translation import gettext as _
 from django_pglocks import advisory_lock
 from django_pglocks import advisory_lock
@@ -295,7 +295,7 @@ class AvailableObjectsView(ObjectValidationMixin, APIView):
 
 
             # Create the new IP address(es)
             # Create the new IP address(es)
             try:
             try:
-                with transaction.atomic():
+                with transaction.atomic(using=router.db_for_write(self.queryset.model)):
                     created = serializer.save()
                     created = serializer.save()
                     self._validate_objects(created)
                     self._validate_objects(created)
             except ObjectDoesNotExist:
             except ObjectDoesNotExist:

+ 3 - 3
netbox/netbox/api/viewsets/__init__.py

@@ -2,7 +2,7 @@ import logging
 from functools import cached_property
 from functools import cached_property
 
 
 from django.core.exceptions import ObjectDoesNotExist, PermissionDenied
 from django.core.exceptions import ObjectDoesNotExist, PermissionDenied
-from django.db import transaction
+from django.db import router, transaction
 from django.db.models import ProtectedError, RestrictedError
 from django.db.models import ProtectedError, RestrictedError
 from django_pglocks import advisory_lock
 from django_pglocks import advisory_lock
 from netbox.constants import ADVISORY_LOCK_KEYS
 from netbox.constants import ADVISORY_LOCK_KEYS
@@ -170,7 +170,7 @@ class NetBoxModelViewSet(
 
 
         # Enforce object-level permissions on save()
         # Enforce object-level permissions on save()
         try:
         try:
-            with transaction.atomic():
+            with transaction.atomic(using=router.db_for_write(model)):
                 instance = serializer.save()
                 instance = serializer.save()
                 self._validate_objects(instance)
                 self._validate_objects(instance)
         except ObjectDoesNotExist:
         except ObjectDoesNotExist:
@@ -190,7 +190,7 @@ class NetBoxModelViewSet(
 
 
         # Enforce object-level permissions on save()
         # Enforce object-level permissions on save()
         try:
         try:
-            with transaction.atomic():
+            with transaction.atomic(using=router.db_for_write(model)):
                 instance = serializer.save()
                 instance = serializer.save()
                 self._validate_objects(instance)
                 self._validate_objects(instance)
         except ObjectDoesNotExist:
         except ObjectDoesNotExist:

+ 16 - 16
netbox/netbox/api/viewsets/mixins.py

@@ -1,5 +1,5 @@
 from django.core.exceptions import ObjectDoesNotExist
 from django.core.exceptions import ObjectDoesNotExist
-from django.db import transaction
+from django.db import router, transaction
 from django.http import Http404
 from django.http import Http404
 from rest_framework import status
 from rest_framework import status
 from rest_framework.response import Response
 from rest_framework.response import Response
@@ -56,22 +56,22 @@ class SequentialBulkCreatesMixin:
     which depends on the evaluation of existing objects (such as checking for free space within a rack) functions
     which depends on the evaluation of existing objects (such as checking for free space within a rack) functions
     appropriately.
     appropriately.
     """
     """
-    @transaction.atomic
     def create(self, request, *args, **kwargs):
     def create(self, request, *args, **kwargs):
-        if not isinstance(request.data, list):
-            # Creating a single object
-            return super().create(request, *args, **kwargs)
-
-        return_data = []
-        for data in request.data:
-            serializer = self.get_serializer(data=data)
-            serializer.is_valid(raise_exception=True)
-            self.perform_create(serializer)
-            return_data.append(serializer.data)
+        with transaction.atomic(using=router.db_for_write(self.queryset.model)):
+            if not isinstance(request.data, list):
+                # Creating a single object
+                return super().create(request, *args, **kwargs)
+
+            return_data = []
+            for data in request.data:
+                serializer = self.get_serializer(data=data)
+                serializer.is_valid(raise_exception=True)
+                self.perform_create(serializer)
+                return_data.append(serializer.data)
 
 
-        headers = self.get_success_headers(serializer.data)
+            headers = self.get_success_headers(serializer.data)
 
 
-        return Response(return_data, status=status.HTTP_201_CREATED, headers=headers)
+            return Response(return_data, status=status.HTTP_201_CREATED, headers=headers)
 
 
 
 
 class BulkUpdateModelMixin:
 class BulkUpdateModelMixin:
@@ -113,7 +113,7 @@ class BulkUpdateModelMixin:
         return Response(data, status=status.HTTP_200_OK)
         return Response(data, status=status.HTTP_200_OK)
 
 
     def perform_bulk_update(self, objects, update_data, partial):
     def perform_bulk_update(self, objects, update_data, partial):
-        with transaction.atomic():
+        with transaction.atomic(using=router.db_for_write(self.queryset.model)):
             data_list = []
             data_list = []
             for obj in objects:
             for obj in objects:
                 data = update_data.get(obj.id)
                 data = update_data.get(obj.id)
@@ -157,7 +157,7 @@ class BulkDestroyModelMixin:
         return Response(status=status.HTTP_204_NO_CONTENT)
         return Response(status=status.HTTP_204_NO_CONTENT)
 
 
     def perform_bulk_destroy(self, objects):
     def perform_bulk_destroy(self, objects):
-        with transaction.atomic():
+        with transaction.atomic(using=router.db_for_write(self.queryset.model)):
             for obj in objects:
             for obj in objects:
                 if hasattr(obj, 'snapshot'):
                 if hasattr(obj, 'snapshot'):
                     obj.snapshot()
                     obj.snapshot()

+ 7 - 7
netbox/netbox/views/generic/bulk_views.py

@@ -6,7 +6,7 @@ from django.contrib import messages
 from django.contrib.contenttypes.fields import GenericForeignKey, GenericRel
 from django.contrib.contenttypes.fields import GenericForeignKey, GenericRel
 from django.contrib.contenttypes.models import ContentType
 from django.contrib.contenttypes.models import ContentType
 from django.core.exceptions import FieldDoesNotExist, ObjectDoesNotExist, ValidationError
 from django.core.exceptions import FieldDoesNotExist, ObjectDoesNotExist, ValidationError
-from django.db import transaction, IntegrityError
+from django.db import IntegrityError, router, transaction
 from django.db.models import ManyToManyField, ProtectedError, RestrictedError
 from django.db.models import ManyToManyField, ProtectedError, RestrictedError
 from django.db.models.fields.reverse_related import ManyToManyRel
 from django.db.models.fields.reverse_related import ManyToManyRel
 from django.forms import ModelMultipleChoiceField, MultipleHiddenInput
 from django.forms import ModelMultipleChoiceField, MultipleHiddenInput
@@ -278,7 +278,7 @@ class BulkCreateView(GetReturnURLMixin, BaseMultiObjectView):
             logger.debug("Form validation was successful")
             logger.debug("Form validation was successful")
 
 
             try:
             try:
-                with transaction.atomic():
+                with transaction.atomic(using=router.db_for_write(model)):
                     new_objs = self._create_objects(form, request)
                     new_objs = self._create_objects(form, request)
 
 
                     # Enforce object-level permissions
                     # Enforce object-level permissions
@@ -501,7 +501,7 @@ class BulkImportView(GetReturnURLMixin, BaseMultiObjectView):
 
 
             try:
             try:
                 # Iterate through data and bind each record to a new model form instance.
                 # Iterate through data and bind each record to a new model form instance.
-                with transaction.atomic():
+                with transaction.atomic(using=router.db_for_write(model)):
                     new_objs = self.create_and_update_objects(form, request)
                     new_objs = self.create_and_update_objects(form, request)
 
 
                     # Enforce object-level permissions
                     # Enforce object-level permissions
@@ -681,7 +681,7 @@ class BulkEditView(GetReturnURLMixin, BaseMultiObjectView):
             if form.is_valid():
             if form.is_valid():
                 logger.debug("Form validation was successful")
                 logger.debug("Form validation was successful")
                 try:
                 try:
-                    with transaction.atomic():
+                    with transaction.atomic(using=router.db_for_write(model)):
                         updated_objects = self._update_objects(form, request)
                         updated_objects = self._update_objects(form, request)
 
 
                         # Enforce object-level permissions
                         # Enforce object-level permissions
@@ -778,7 +778,7 @@ class BulkRenameView(GetReturnURLMixin, BaseMultiObjectView):
 
 
             if form.is_valid():
             if form.is_valid():
                 try:
                 try:
-                    with transaction.atomic():
+                    with transaction.atomic(using=router.db_for_write(self.queryset.model)):
                         renamed_pks = self._rename_objects(form, selected_objects)
                         renamed_pks = self._rename_objects(form, selected_objects)
 
 
                         if '_apply' in request.POST:
                         if '_apply' in request.POST:
@@ -875,7 +875,7 @@ class BulkDeleteView(GetReturnURLMixin, BaseMultiObjectView):
                 queryset = self.queryset.filter(pk__in=pk_list)
                 queryset = self.queryset.filter(pk__in=pk_list)
                 deleted_count = queryset.count()
                 deleted_count = queryset.count()
                 try:
                 try:
-                    with transaction.atomic():
+                    with transaction.atomic(using=router.db_for_write(model)):
                         for obj in queryset:
                         for obj in queryset:
                             # Take a snapshot of change-logged models
                             # Take a snapshot of change-logged models
                             if hasattr(obj, 'snapshot'):
                             if hasattr(obj, 'snapshot'):
@@ -980,7 +980,7 @@ class BulkComponentCreateView(GetReturnURLMixin, BaseMultiObjectView):
                 }
                 }
 
 
                 try:
                 try:
-                    with transaction.atomic():
+                    with transaction.atomic(using=router.db_for_write(self.queryset.model)):
 
 
                         for obj in data['pk']:
                         for obj in data['pk']:
 
 

+ 2 - 2
netbox/netbox/views/generic/feature_views.py

@@ -1,7 +1,7 @@
 from django.contrib.auth.mixins import LoginRequiredMixin
 from django.contrib.auth.mixins import LoginRequiredMixin
 from django.contrib.contenttypes.models import ContentType
 from django.contrib.contenttypes.models import ContentType
 from django.contrib import messages
 from django.contrib import messages
-from django.db import transaction
+from django.db import router, transaction
 from django.db.models import Q
 from django.db.models import Q
 from django.shortcuts import get_object_or_404, redirect, render
 from django.shortcuts import get_object_or_404, redirect, render
 from django.utils.translation import gettext_lazy as _
 from django.utils.translation import gettext_lazy as _
@@ -240,7 +240,7 @@ class BulkSyncDataView(GetReturnURLMixin, BaseMultiObjectView):
             data_file__isnull=False
             data_file__isnull=False
         )
         )
 
 
-        with transaction.atomic():
+        with transaction.atomic(using=router.db_for_write(self.queryset.model)):
             for obj in selected_objects:
             for obj in selected_objects:
                 obj.sync(save=True)
                 obj.sync(save=True)
 
 

+ 2 - 2
netbox/netbox/views/generic/object_views.py

@@ -282,7 +282,7 @@ class ObjectEditView(GetReturnURLMixin, BaseObjectView):
             logger.debug("Form validation was successful")
             logger.debug("Form validation was successful")
 
 
             try:
             try:
-                with transaction.atomic():
+                with transaction.atomic(using=router.db_for_write(model)):
                     object_created = form.instance.pk is None
                     object_created = form.instance.pk is None
                     obj = form.save()
                     obj = form.save()
 
 
@@ -570,7 +570,7 @@ class ComponentCreateView(GetReturnURLMixin, BaseObjectView):
 
 
             if not form.errors and not component_form.errors:
             if not form.errors and not component_form.errors:
                 try:
                 try:
-                    with transaction.atomic():
+                    with transaction.atomic(using=router.db_for_write(self.queryset.model)):
                         # Create the new components
                         # Create the new components
                         new_objs = []
                         new_objs = []
                         for component_form in new_components:
                         for component_form in new_components:

+ 3 - 3
netbox/virtualization/views.py

@@ -1,6 +1,6 @@
 from django.contrib import messages
 from django.contrib import messages
 from django.contrib.contenttypes.models import ContentType
 from django.contrib.contenttypes.models import ContentType
-from django.db import transaction
+from django.db import router, transaction
 from django.db.models import Prefetch, Sum
 from django.db.models import Prefetch, Sum
 from django.shortcuts import get_object_or_404, redirect, render
 from django.shortcuts import get_object_or_404, redirect, render
 from django.urls import reverse
 from django.urls import reverse
@@ -297,7 +297,7 @@ class ClusterAddDevicesView(generic.ObjectEditView):
         if form.is_valid():
         if form.is_valid():
 
 
             device_pks = form.cleaned_data['devices']
             device_pks = form.cleaned_data['devices']
-            with transaction.atomic():
+            with transaction.atomic(using=router.db_for_write(Device)):
 
 
                 # Assign the selected Devices to the Cluster
                 # Assign the selected Devices to the Cluster
                 for device in Device.objects.filter(pk__in=device_pks):
                 for device in Device.objects.filter(pk__in=device_pks):
@@ -332,7 +332,7 @@ class ClusterRemoveDevicesView(generic.ObjectEditView):
             if form.is_valid():
             if form.is_valid():
 
 
                 device_pks = form.cleaned_data['pk']
                 device_pks = form.cleaned_data['pk']
-                with transaction.atomic():
+                with transaction.atomic(using=router.db_for_write(Device)):
 
 
                     # Remove the selected Devices from the Cluster
                     # Remove the selected Devices from the Cluster
                     for device in Device.objects.filter(pk__in=device_pks):
                     for device in Device.objects.filter(pk__in=device_pks):