Quellcode durchsuchen

Refactor bulk generic views

jeremystretch vor 4 Jahren
Ursprung
Commit
e91a76c936
1 geänderte Dateien mit 191 neuen und 171 gelöschten Zeilen
  1. 191 171
      netbox/netbox/views/generic.py

+ 191 - 171
netbox/netbox/views/generic.py

@@ -539,6 +539,31 @@ class BulkCreateView(GetReturnURLMixin, ObjectPermissionRequiredMixin, View):
     def get_required_permission(self):
     def get_required_permission(self):
         return get_permission_for_model(self.queryset.model, 'add')
         return get_permission_for_model(self.queryset.model, 'add')
 
 
+    def _create_objects(self, form, request):
+        new_objects = []
+
+        # Create objects from the expanded. Abort the transaction on the first validation error.
+        for value in form.cleaned_data['pattern']:
+
+            # Reinstantiate the model form each time to avoid overwriting the same instance. Use a mutable
+            # copy of the POST QueryDict so that we can update the target field value.
+            model_form = self.model_form(request.POST.copy())
+            model_form.data[self.pattern_target] = value
+
+            # Validate each new object independently.
+            if model_form.is_valid():
+                obj = model_form.save()
+                new_objects.append(obj)
+            else:
+                # Copy any errors on the pattern target field to the pattern form.
+                errors = model_form.errors.as_data()
+                if errors.get(self.pattern_target):
+                    form.add_error('pattern', errors[self.pattern_target])
+                # Raise an IntegrityError to break the for loop and abort the transaction.
+                raise IntegrityError()
+
+        return new_objects
+
     def get(self, request):
     def get(self, request):
         # Set initial values for visible form fields from query args
         # Set initial values for visible form fields from query args
         initial = {}
         initial = {}
@@ -564,45 +589,23 @@ class BulkCreateView(GetReturnURLMixin, ObjectPermissionRequiredMixin, View):
 
 
         if form.is_valid():
         if form.is_valid():
             logger.debug("Form validation was successful")
             logger.debug("Form validation was successful")
-            pattern = form.cleaned_data['pattern']
-            new_objs = []
 
 
             try:
             try:
                 with transaction.atomic():
                 with transaction.atomic():
-
-                    # Create objects from the expanded. Abort the transaction on the first validation error.
-                    for value in pattern:
-
-                        # Reinstantiate the model form each time to avoid overwriting the same instance. Use a mutable
-                        # copy of the POST QueryDict so that we can update the target field value.
-                        model_form = self.model_form(request.POST.copy())
-                        model_form.data[self.pattern_target] = value
-
-                        # Validate each new object independently.
-                        if model_form.is_valid():
-                            obj = model_form.save()
-                            logger.debug(f"Created {obj} (PK: {obj.pk})")
-                            new_objs.append(obj)
-                        else:
-                            # Copy any errors on the pattern target field to the pattern form.
-                            errors = model_form.errors.as_data()
-                            if errors.get(self.pattern_target):
-                                form.add_error('pattern', errors[self.pattern_target])
-                            # Raise an IntegrityError to break the for loop and abort the transaction.
-                            raise IntegrityError()
+                    new_objs = self._create_objects(form, request)
 
 
                     # Enforce object-level permissions
                     # Enforce object-level permissions
                     if self.queryset.filter(pk__in=[obj.pk for obj in new_objs]).count() != len(new_objs):
                     if self.queryset.filter(pk__in=[obj.pk for obj in new_objs]).count() != len(new_objs):
                         raise PermissionsViolation
                         raise PermissionsViolation
 
 
-                    # If we make it to this point, validation has succeeded on all new objects.
-                    msg = "Added {} {}".format(len(new_objs), model._meta.verbose_name_plural)
-                    logger.info(msg)
-                    messages.success(request, msg)
+                # If we make it to this point, validation has succeeded on all new objects.
+                msg = f"Added {len(new_objs)} {model._meta.verbose_name_plural}"
+                logger.info(msg)
+                messages.success(request, msg)
 
 
-                    if '_addanother' in request.POST:
-                        return redirect(request.path)
-                    return redirect(self.get_return_url(request))
+                if '_addanother' in request.POST:
+                    return redirect(request.path)
+                return redirect(self.get_return_url(request))
 
 
             except IntegrityError:
             except IntegrityError:
                 pass
                 pass
@@ -640,6 +643,45 @@ class ObjectImportView(GetReturnURLMixin, ObjectPermissionRequiredMixin, View):
     def get_required_permission(self):
     def get_required_permission(self):
         return get_permission_for_model(self.queryset.model, 'add')
         return get_permission_for_model(self.queryset.model, 'add')
 
 
+    def _create_object(self, model_form):
+
+        # Save the primary object
+        obj = model_form.save()
+
+        # Enforce object-level permissions
+        if not self.queryset.filter(pk=obj.pk).first():
+            raise PermissionsViolation()
+
+        # Iterate through the related object forms (if any), validating and saving each instance.
+        for field_name, related_object_form in self.related_object_forms.items():
+
+            related_obj_pks = []
+            for i, rel_obj_data in enumerate(model_form.data.get(field_name, list())):
+
+                f = related_object_form(obj, rel_obj_data)
+
+                for subfield_name, field in f.fields.items():
+                    if subfield_name not in rel_obj_data and hasattr(field, 'initial'):
+                        f.data[subfield_name] = field.initial
+
+                if f.is_valid():
+                    related_obj = f.save()
+                    related_obj_pks.append(related_obj.pk)
+                else:
+                    # Replicate errors on the related object form to the primary form for display
+                    for subfield_name, errors in f.errors.items():
+                        for err in errors:
+                            err_msg = "{}[{}] {}: {}".format(field_name, i, subfield_name, err)
+                            model_form.add_error(None, err_msg)
+                    raise AbortTransaction()
+
+            # Enforce object-level permissions on related objects
+            model = related_object_form.Meta.model
+            if model.objects.filter(pk__in=related_obj_pks).count() != len(related_obj_pks):
+                raise ObjectDoesNotExist
+
+        return obj
+
     def get(self, request):
     def get(self, request):
         form = ImportForm()
         form = ImportForm()
 
 
@@ -673,44 +715,7 @@ class ObjectImportView(GetReturnURLMixin, ObjectPermissionRequiredMixin, View):
 
 
                 try:
                 try:
                     with transaction.atomic():
                     with transaction.atomic():
-
-                        # Save the primary object
-                        obj = model_form.save()
-
-                        # Enforce object-level permissions
-                        if not self.queryset.filter(pk=obj.pk).first():
-                            raise PermissionsViolation()
-
-                        logger.debug(f"Created {obj} (PK: {obj.pk})")
-
-                        # Iterate through the related object forms (if any), validating and saving each instance.
-                        for field_name, related_object_form in self.related_object_forms.items():
-                            logger.debug("Processing form for related objects: {related_object_form}")
-
-                            related_obj_pks = []
-                            for i, rel_obj_data in enumerate(data.get(field_name, list())):
-
-                                f = related_object_form(obj, rel_obj_data)
-
-                                for subfield_name, field in f.fields.items():
-                                    if subfield_name not in rel_obj_data and hasattr(field, 'initial'):
-                                        f.data[subfield_name] = field.initial
-
-                                if f.is_valid():
-                                    related_obj = f.save()
-                                    related_obj_pks.append(related_obj.pk)
-                                else:
-                                    # Replicate errors on the related object form to the primary form for display
-                                    for subfield_name, errors in f.errors.items():
-                                        for err in errors:
-                                            err_msg = "{}[{}] {}: {}".format(field_name, i, subfield_name, err)
-                                            model_form.add_error(None, err_msg)
-                                    raise AbortTransaction()
-
-                            # Enforce object-level permissions on related objects
-                            model = related_object_form.Meta.model
-                            if model.objects.filter(pk__in=related_obj_pks).count() != len(related_obj_pks):
-                                raise ObjectDoesNotExist
+                        obj = self._create_object(model_form)
 
 
                 except AbortTransaction:
                 except AbortTransaction:
                     clear_webhooks.send(sender=self)
                     clear_webhooks.send(sender=self)
@@ -723,9 +728,8 @@ class ObjectImportView(GetReturnURLMixin, ObjectPermissionRequiredMixin, View):
 
 
             if not model_form.errors:
             if not model_form.errors:
                 logger.info(f"Import object {obj} (PK: {obj.pk})")
                 logger.info(f"Import object {obj} (PK: {obj.pk})")
-                messages.success(request, mark_safe('Imported object: <a href="{}">{}</a>'.format(
-                    obj.get_absolute_url(), obj
-                )))
+                msg = f'Imported object: <a href="{obj.get_absolute_url()}">{obj}</a>'
+                messages.success(request, mark_safe(msg))
 
 
                 if '_addanother' in request.POST:
                 if '_addanother' in request.POST:
                     return redirect(request.get_full_path())
                     return redirect(request.get_full_path())
@@ -733,8 +737,7 @@ class ObjectImportView(GetReturnURLMixin, ObjectPermissionRequiredMixin, View):
                 return_url = form.cleaned_data.get('return_url')
                 return_url = form.cleaned_data.get('return_url')
                 if return_url is not None and is_safe_url(url=return_url, allowed_hosts=request.get_host()):
                 if return_url is not None and is_safe_url(url=return_url, allowed_hosts=request.get_host()):
                     return redirect(return_url)
                     return redirect(return_url)
-                else:
-                    return redirect(self.get_return_url(request, obj))
+                return redirect(self.get_return_url(request, obj))
 
 
             else:
             else:
                 logger.debug("Model form validation failed")
                 logger.debug("Model form validation failed")
@@ -799,6 +802,27 @@ class BulkImportView(GetReturnURLMixin, ObjectPermissionRequiredMixin, View):
 
 
         return ImportForm(*args, **kwargs)
         return ImportForm(*args, **kwargs)
 
 
+    def _create_objects(self, form, request):
+        new_objs = []
+        if request.FILES:
+            headers, records = form.cleaned_data['csv_file']
+        else:
+            headers, records = form.cleaned_data['csv']
+
+        for row, data in enumerate(records, start=1):
+            obj_form = self.model_form(data, headers=headers)
+            restrict_form_fields(obj_form, request.user)
+
+            if obj_form.is_valid():
+                obj = self._save_obj(obj_form, request)
+                new_objs.append(obj)
+            else:
+                for field, err in obj_form.errors.items():
+                    form.add_error('csv', f'Row {row} {field}: {err[0]}')
+                raise ValidationError("")
+
+        return new_objs
+
     def _save_obj(self, obj_form, request):
     def _save_obj(self, obj_form, request):
         """
         """
         Provide a hook to modify the object immediately before saving it (e.g. to encrypt secret data).
         Provide a hook to modify the object immediately before saving it (e.g. to encrypt secret data).
@@ -819,7 +843,6 @@ class BulkImportView(GetReturnURLMixin, ObjectPermissionRequiredMixin, View):
 
 
     def post(self, request):
     def post(self, request):
         logger = logging.getLogger('netbox.views.BulkImportView')
         logger = logging.getLogger('netbox.views.BulkImportView')
-        new_objs = []
         form = self._import_form(request.POST, request.FILES)
         form = self._import_form(request.POST, request.FILES)
 
 
         if form.is_valid():
         if form.is_valid():
@@ -828,21 +851,7 @@ class BulkImportView(GetReturnURLMixin, ObjectPermissionRequiredMixin, View):
             try:
             try:
                 # Iterate through CSV data and bind each row to a new model form instance.
                 # Iterate through CSV data and bind each row to a new model form instance.
                 with transaction.atomic():
                 with transaction.atomic():
-                    if request.FILES:
-                        headers, records = form.cleaned_data['csv_file']
-                    else:
-                        headers, records = form.cleaned_data['csv']
-                    for row, data in enumerate(records, start=1):
-                        obj_form = self.model_form(data, headers=headers)
-                        restrict_form_fields(obj_form, request.user)
-
-                        if obj_form.is_valid():
-                            obj = self._save_obj(obj_form, request)
-                            new_objs.append(obj)
-                        else:
-                            for field, err in obj_form.errors.items():
-                                form.add_error('csv', "Row {} {}: {}".format(row, field, err[0]))
-                            raise ValidationError("")
+                    new_objs = self._create_objects(form, request)
 
 
                     # Enforce object-level permissions
                     # Enforce object-level permissions
                     if self.queryset.filter(pk__in=[obj.pk for obj in new_objs]).count() != len(new_objs):
                     if self.queryset.filter(pk__in=[obj.pk for obj in new_objs]).count() != len(new_objs):
@@ -886,7 +895,7 @@ class BulkEditView(GetReturnURLMixin, ObjectPermissionRequiredMixin, View):
     Edit objects in bulk.
     Edit objects in bulk.
 
 
     queryset: Custom queryset to use when retrieving objects (e.g. to select related objects)
     queryset: Custom queryset to use when retrieving objects (e.g. to select related objects)
-    filter: FilterSet to apply when deleting by QuerySet
+    filterset: FilterSet to apply when deleting by QuerySet
     table: The table used to display devices being edited
     table: The table used to display devices being edited
     form: The form class used to edit objects in bulk
     form: The form class used to edit objects in bulk
     template_name: The name of the template
     template_name: The name of the template
@@ -900,6 +909,63 @@ class BulkEditView(GetReturnURLMixin, ObjectPermissionRequiredMixin, View):
     def get_required_permission(self):
     def get_required_permission(self):
         return get_permission_for_model(self.queryset.model, 'change')
         return get_permission_for_model(self.queryset.model, 'change')
 
 
+    def _update_objects(self, form, request):
+        custom_fields = form.custom_fields if hasattr(form, 'custom_fields') else []
+        standard_fields = [
+            field for field in form.fields if field not in custom_fields + ['pk']
+        ]
+        nullified_fields = request.POST.getlist('_nullify')
+        updated_objects = []
+
+        for obj in self.queryset.filter(pk__in=form.cleaned_data['pk']):
+
+            # Take a snapshot of change-logged models
+            if hasattr(obj, 'snapshot'):
+                obj.snapshot()
+
+            # Update standard fields. If a field is listed in _nullify, delete its value.
+            for name in standard_fields:
+
+                try:
+                    model_field = self.queryset.model._meta.get_field(name)
+                except FieldDoesNotExist:
+                    # This form field is used to modify a field rather than set its value directly
+                    model_field = None
+
+                # Handle nullification
+                if name in form.nullable_fields and name in nullified_fields:
+                    if isinstance(model_field, ManyToManyField):
+                        getattr(obj, name).set([])
+                    else:
+                        setattr(obj, name, None if model_field.null else '')
+
+                # ManyToManyFields
+                elif isinstance(model_field, ManyToManyField):
+                    if form.cleaned_data[name]:
+                        getattr(obj, name).set(form.cleaned_data[name])
+                # Normal fields
+                elif name in form.changed_data:
+                    setattr(obj, name, form.cleaned_data[name])
+
+            # Update custom fields
+            for name in custom_fields:
+                if name in form.nullable_fields and name in nullified_fields:
+                    obj.custom_field_data[name] = None
+                elif name in form.changed_data:
+                    obj.custom_field_data[name] = form.cleaned_data[name]
+
+            obj.full_clean()
+            obj.save()
+            updated_objects.append(obj)
+
+            # Add/remove tags
+            if form.cleaned_data.get('add_tags', None):
+                obj.tags.add(*form.cleaned_data['add_tags'])
+            if form.cleaned_data.get('remove_tags', None):
+                obj.tags.remove(*form.cleaned_data['remove_tags'])
+
+        return updated_objects
+
     def get(self, request):
     def get(self, request):
         return redirect(self.get_return_url(request))
         return redirect(self.get_return_url(request))
 
 
@@ -932,78 +998,26 @@ class BulkEditView(GetReturnURLMixin, ObjectPermissionRequiredMixin, View):
 
 
             if form.is_valid():
             if form.is_valid():
                 logger.debug("Form validation was successful")
                 logger.debug("Form validation was successful")
-                custom_fields = form.custom_fields if hasattr(form, 'custom_fields') else []
-                standard_fields = [
-                    field for field in form.fields if field not in custom_fields + ['pk']
-                ]
-                nullified_fields = request.POST.getlist('_nullify')
 
 
                 try:
                 try:
 
 
                     with transaction.atomic():
                     with transaction.atomic():
-
-                        updated_objects = []
-                        for obj in self.queryset.filter(pk__in=form.cleaned_data['pk']):
-
-                            # Take a snapshot of change-logged models
-                            if hasattr(obj, 'snapshot'):
-                                obj.snapshot()
-
-                            # Update standard fields. If a field is listed in _nullify, delete its value.
-                            for name in standard_fields:
-
-                                try:
-                                    model_field = model._meta.get_field(name)
-                                except FieldDoesNotExist:
-                                    # This form field is used to modify a field rather than set its value directly
-                                    model_field = None
-
-                                # Handle nullification
-                                if name in form.nullable_fields and name in nullified_fields:
-                                    if isinstance(model_field, ManyToManyField):
-                                        getattr(obj, name).set([])
-                                    else:
-                                        setattr(obj, name, None if model_field.null else '')
-
-                                # ManyToManyFields
-                                elif isinstance(model_field, ManyToManyField):
-                                    if form.cleaned_data[name]:
-                                        getattr(obj, name).set(form.cleaned_data[name])
-                                # Normal fields
-                                elif name in form.changed_data:
-                                    setattr(obj, name, form.cleaned_data[name])
-
-                            # Update custom fields
-                            for name in custom_fields:
-                                if name in form.nullable_fields and name in nullified_fields:
-                                    obj.custom_field_data[name] = None
-                                elif name in form.changed_data:
-                                    obj.custom_field_data[name] = form.cleaned_data[name]
-
-                            obj.full_clean()
-                            obj.save()
-                            updated_objects.append(obj)
-                            logger.debug(f"Saved {obj} (PK: {obj.pk})")
-
-                            # Add/remove tags
-                            if form.cleaned_data.get('add_tags', None):
-                                obj.tags.add(*form.cleaned_data['add_tags'])
-                            if form.cleaned_data.get('remove_tags', None):
-                                obj.tags.remove(*form.cleaned_data['remove_tags'])
+                        updated_objects = self._update_objects(form, request)
 
 
                         # Enforce object-level permissions
                         # Enforce object-level permissions
-                        if self.queryset.filter(pk__in=[obj.pk for obj in updated_objects]).count() != len(updated_objects):
+                        object_count = self.queryset.filter(pk__in=[obj.pk for obj in updated_objects]).count()
+                        if object_count != len(updated_objects):
                             raise PermissionsViolation
                             raise PermissionsViolation
 
 
                     if updated_objects:
                     if updated_objects:
-                        msg = 'Updated {} {}'.format(len(updated_objects), model._meta.verbose_name_plural)
+                        msg = f'Updated {len(updated_objects)} {model._meta.verbose_name_plural}'
                         logger.info(msg)
                         logger.info(msg)
                         messages.success(self.request, msg)
                         messages.success(self.request, msg)
 
 
                     return redirect(self.get_return_url(request))
                     return redirect(self.get_return_url(request))
 
 
                 except ValidationError as e:
                 except ValidationError as e:
-                    messages.error(self.request, "{} failed validation: {}".format(obj, ", ".join(e.messages)))
+                    messages.error(self.request, ", ".join(e.messages))
                     clear_webhooks.send(sender=self)
                     clear_webhooks.send(sender=self)
 
 
                 except PermissionsViolation:
                 except PermissionsViolation:
@@ -1016,7 +1030,6 @@ class BulkEditView(GetReturnURLMixin, ObjectPermissionRequiredMixin, View):
                 logger.debug("Form validation failed")
                 logger.debug("Form validation failed")
 
 
         else:
         else:
-
             form = self.form(model, initial=initial_data)
             form = self.form(model, initial=initial_data)
             restrict_form_fields(form, request.user)
             restrict_form_fields(form, request.user)
 
 
@@ -1037,6 +1050,9 @@ class BulkEditView(GetReturnURLMixin, ObjectPermissionRequiredMixin, View):
 class BulkRenameView(GetReturnURLMixin, ObjectPermissionRequiredMixin, View):
 class BulkRenameView(GetReturnURLMixin, ObjectPermissionRequiredMixin, View):
     """
     """
     An extendable view for renaming objects in bulk.
     An extendable view for renaming objects in bulk.
+
+    queryset: QuerySet of objects being renamed
+    template_name: The name of the template
     """
     """
     queryset = None
     queryset = None
     template_name = 'generic/object_bulk_rename.html'
     template_name = 'generic/object_bulk_rename.html'
@@ -1056,6 +1072,29 @@ class BulkRenameView(GetReturnURLMixin, ObjectPermissionRequiredMixin, View):
     def get_required_permission(self):
     def get_required_permission(self):
         return get_permission_for_model(self.queryset.model, 'change')
         return get_permission_for_model(self.queryset.model, 'change')
 
 
+    def _rename_objects(self, form, selected_objects):
+        renamed_pks = []
+
+        for obj in selected_objects:
+
+            # Take a snapshot of change-logged models
+            if hasattr(obj, 'snapshot'):
+                obj.snapshot()
+
+            find = form.cleaned_data['find']
+            replace = form.cleaned_data['replace']
+            if form.cleaned_data['use_regex']:
+                try:
+                    obj.new_name = re.sub(find, replace, obj.name)
+                # Catch regex group reference errors
+                except re.error:
+                    obj.new_name = obj.name
+            else:
+                obj.new_name = obj.name.replace(find, replace)
+            renamed_pks.append(obj.pk)
+
+        return renamed_pks
+
     def post(self, request):
     def post(self, request):
         logger = logging.getLogger('netbox.views.BulkRenameView')
         logger = logging.getLogger('netbox.views.BulkRenameView')
 
 
@@ -1066,24 +1105,7 @@ class BulkRenameView(GetReturnURLMixin, ObjectPermissionRequiredMixin, View):
             if form.is_valid():
             if form.is_valid():
                 try:
                 try:
                     with transaction.atomic():
                     with transaction.atomic():
-                        renamed_pks = []
-                        for obj in selected_objects:
-
-                            # Take a snapshot of change-logged models
-                            if hasattr(obj, 'snapshot'):
-                                obj.snapshot()
-
-                            find = form.cleaned_data['find']
-                            replace = form.cleaned_data['replace']
-                            if form.cleaned_data['use_regex']:
-                                try:
-                                    obj.new_name = re.sub(find, replace, obj.name)
-                                # Catch regex group reference errors
-                                except re.error:
-                                    obj.new_name = obj.name
-                            else:
-                                obj.new_name = obj.name.replace(find, replace)
-                            renamed_pks.append(obj.pk)
+                        renamed_pks = self._rename_objects(form, selected_objects)
 
 
                         if '_apply' in request.POST:
                         if '_apply' in request.POST:
                             for obj in selected_objects:
                             for obj in selected_objects:
@@ -1094,10 +1116,8 @@ class BulkRenameView(GetReturnURLMixin, ObjectPermissionRequiredMixin, View):
                             if self.queryset.filter(pk__in=renamed_pks).count() != len(selected_objects):
                             if self.queryset.filter(pk__in=renamed_pks).count() != len(selected_objects):
                                 raise PermissionsViolation
                                 raise PermissionsViolation
 
 
-                            messages.success(request, "Renamed {} {}".format(
-                                len(selected_objects),
-                                self.queryset.model._meta.verbose_name_plural
-                            ))
+                            model_name = self.queryset.model._meta.verbose_name_plural
+                            messages.success(request, f"Renamed {len(selected_objects)} {model_name}")
                             return redirect(self.get_return_url(request))
                             return redirect(self.get_return_url(request))
 
 
                 except PermissionsViolation:
                 except PermissionsViolation:
@@ -1123,7 +1143,7 @@ class BulkDeleteView(GetReturnURLMixin, ObjectPermissionRequiredMixin, View):
     Delete objects in bulk.
     Delete objects in bulk.
 
 
     queryset: Custom queryset to use when retrieving objects (e.g. to select related objects)
     queryset: Custom queryset to use when retrieving objects (e.g. to select related objects)
-    filter: FilterSet to apply when deleting by QuerySet
+    filterset: FilterSet to apply when deleting by QuerySet
     table: The table used to display devices being deleted
     table: The table used to display devices being deleted
     form: The form class used to delete objects in bulk
     form: The form class used to delete objects in bulk
     template_name: The name of the template
     template_name: The name of the template