Jelajahi Sumber

Closes #1851: Standardize usage of GetReturnURLMixin

Jeremy Stretch 7 tahun lalu
induk
melakukan
cd56e51a61

+ 1 - 4
netbox/circuits/views.py

@@ -6,7 +6,6 @@ from django.contrib.auth.mixins import PermissionRequiredMixin
 from django.db import transaction
 from django.db import transaction
 from django.db.models import Count
 from django.db.models import Count
 from django.shortcuts import get_object_or_404, redirect, render
 from django.shortcuts import get_object_or_404, redirect, render
-from django.urls import reverse
 from django.views.generic import View
 from django.views.generic import View
 
 
 from extras.models import Graph, GRAPH_TYPE_PROVIDER
 from extras.models import Graph, GRAPH_TYPE_PROVIDER
@@ -106,9 +105,7 @@ class CircuitTypeCreateView(PermissionRequiredMixin, ObjectEditView):
     permission_required = 'circuits.add_circuittype'
     permission_required = 'circuits.add_circuittype'
     model = CircuitType
     model = CircuitType
     model_form = forms.CircuitTypeForm
     model_form = forms.CircuitTypeForm
-
-    def get_return_url(self, request, obj):
-        return reverse('circuits:circuittype_list')
+    default_return_url = 'circuits:circuittype_list'
 
 
 
 
 class CircuitTypeEditView(CircuitTypeCreateView):
 class CircuitTypeEditView(CircuitTypeCreateView):

+ 10 - 26
netbox/dcim/views.py

@@ -12,7 +12,7 @@ from django.http import HttpResponseRedirect
 from django.shortcuts import get_object_or_404, redirect, render
 from django.shortcuts import get_object_or_404, redirect, render
 from django.urls import reverse
 from django.urls import reverse
 from django.utils.html import escape
 from django.utils.html import escape
-from django.utils.http import is_safe_url, urlencode
+from django.utils.http import urlencode
 from django.utils.safestring import mark_safe
 from django.utils.safestring import mark_safe
 from django.views.generic import View
 from django.views.generic import View
 from natsort import natsorted
 from natsort import natsorted
@@ -38,7 +38,7 @@ from .models import (
 )
 )
 
 
 
 
-class BulkRenameView(View):
+class BulkRenameView(GetReturnURLMixin, View):
     """
     """
     An extendable view for renaming device components in bulk.
     An extendable view for renaming device components in bulk.
     """
     """
@@ -50,10 +50,6 @@ class BulkRenameView(View):
 
 
         model = self.queryset.model
         model = self.queryset.model
 
 
-        return_url = request.GET.get('return_url')
-        if not return_url or not is_safe_url(url=return_url, host=request.get_host()):
-            return_url = 'home'
-
         if '_preview' in request.POST or '_apply' in request.POST:
         if '_preview' in request.POST or '_apply' in request.POST:
             form = self.form(request.POST, initial={'pk': request.POST.getlist('pk')})
             form = self.form(request.POST, initial={'pk': request.POST.getlist('pk')})
             selected_objects = self.queryset.filter(pk__in=form.initial['pk'])
             selected_objects = self.queryset.filter(pk__in=form.initial['pk'])
@@ -70,7 +66,7 @@ class BulkRenameView(View):
                         len(selected_objects),
                         len(selected_objects),
                         model._meta.verbose_name_plural
                         model._meta.verbose_name_plural
                     ))
                     ))
-                    return redirect(return_url)
+                    return redirect(self.get_return_url(request))
 
 
         else:
         else:
             form = self.form(initial={'pk': request.POST.getlist('pk')})
             form = self.form(initial={'pk': request.POST.getlist('pk')})
@@ -80,7 +76,7 @@ class BulkRenameView(View):
             'form': form,
             'form': form,
             'obj_type_plural': model._meta.verbose_name_plural,
             'obj_type_plural': model._meta.verbose_name_plural,
             'selected_objects': selected_objects,
             'selected_objects': selected_objects,
-            'return_url': return_url,
+            'return_url': self.get_return_url(request),
         })
         })
 
 
 
 
@@ -138,9 +134,7 @@ class RegionCreateView(PermissionRequiredMixin, ObjectEditView):
     permission_required = 'dcim.add_region'
     permission_required = 'dcim.add_region'
     model = Region
     model = Region
     model_form = forms.RegionForm
     model_form = forms.RegionForm
-
-    def get_return_url(self, request, obj):
-        return reverse('dcim:region_list')
+    default_return_url = 'dcim:region_list'
 
 
 
 
 class RegionEditView(RegionCreateView):
 class RegionEditView(RegionCreateView):
@@ -252,9 +246,7 @@ class RackGroupCreateView(PermissionRequiredMixin, ObjectEditView):
     permission_required = 'dcim.add_rackgroup'
     permission_required = 'dcim.add_rackgroup'
     model = RackGroup
     model = RackGroup
     model_form = forms.RackGroupForm
     model_form = forms.RackGroupForm
-
-    def get_return_url(self, request, obj):
-        return reverse('dcim:rackgroup_list')
+    default_return_url = 'dcim:rackgroup_list'
 
 
 
 
 class RackGroupEditView(RackGroupCreateView):
 class RackGroupEditView(RackGroupCreateView):
@@ -291,9 +283,7 @@ class RackRoleCreateView(PermissionRequiredMixin, ObjectEditView):
     permission_required = 'dcim.add_rackrole'
     permission_required = 'dcim.add_rackrole'
     model = RackRole
     model = RackRole
     model_form = forms.RackRoleForm
     model_form = forms.RackRoleForm
-
-    def get_return_url(self, request, obj):
-        return reverse('dcim:rackrole_list')
+    default_return_url = 'dcim:rackrole_list'
 
 
 
 
 class RackRoleEditView(RackRoleCreateView):
 class RackRoleEditView(RackRoleCreateView):
@@ -515,9 +505,7 @@ class ManufacturerCreateView(PermissionRequiredMixin, ObjectEditView):
     permission_required = 'dcim.add_manufacturer'
     permission_required = 'dcim.add_manufacturer'
     model = Manufacturer
     model = Manufacturer
     model_form = forms.ManufacturerForm
     model_form = forms.ManufacturerForm
-
-    def get_return_url(self, request, obj):
-        return reverse('dcim:manufacturer_list')
+    default_return_url = 'dcim:manufacturer_list'
 
 
 
 
 class ManufacturerEditView(ManufacturerCreateView):
 class ManufacturerEditView(ManufacturerCreateView):
@@ -777,9 +765,7 @@ class DeviceRoleCreateView(PermissionRequiredMixin, ObjectEditView):
     permission_required = 'dcim.add_devicerole'
     permission_required = 'dcim.add_devicerole'
     model = DeviceRole
     model = DeviceRole
     model_form = forms.DeviceRoleForm
     model_form = forms.DeviceRoleForm
-
-    def get_return_url(self, request, obj):
-        return reverse('dcim:devicerole_list')
+    default_return_url = 'dcim:devicerole_list'
 
 
 
 
 class DeviceRoleEditView(DeviceRoleCreateView):
 class DeviceRoleEditView(DeviceRoleCreateView):
@@ -815,9 +801,7 @@ class PlatformCreateView(PermissionRequiredMixin, ObjectEditView):
     permission_required = 'dcim.add_platform'
     permission_required = 'dcim.add_platform'
     model = Platform
     model = Platform
     model_form = forms.PlatformForm
     model_form = forms.PlatformForm
-
-    def get_return_url(self, request, obj):
-        return reverse('dcim:platform_list')
+    default_return_url = 'dcim:platform_list'
 
 
 
 
 class PlatformEditView(PlatformCreateView):
 class PlatformEditView(PlatformCreateView):

+ 3 - 10
netbox/ipam/views.py

@@ -5,7 +5,6 @@ from django.conf import settings
 from django.contrib.auth.mixins import PermissionRequiredMixin
 from django.contrib.auth.mixins import PermissionRequiredMixin
 from django.db.models import Count, Q
 from django.db.models import Count, Q
 from django.shortcuts import get_object_or_404, redirect, render
 from django.shortcuts import get_object_or_404, redirect, render
-from django.urls import reverse
 from django.views.generic import View
 from django.views.generic import View
 from django_tables2 import RequestConfig
 from django_tables2 import RequestConfig
 
 
@@ -248,9 +247,7 @@ class RIRCreateView(PermissionRequiredMixin, ObjectEditView):
     permission_required = 'ipam.add_rir'
     permission_required = 'ipam.add_rir'
     model = RIR
     model = RIR
     model_form = forms.RIRForm
     model_form = forms.RIRForm
-
-    def get_return_url(self, request, obj):
-        return reverse('ipam:rir_list')
+    default_return_url = 'ipam:rir_list'
 
 
 
 
 class RIREditView(RIRCreateView):
 class RIREditView(RIRCreateView):
@@ -401,9 +398,7 @@ class RoleCreateView(PermissionRequiredMixin, ObjectEditView):
     permission_required = 'ipam.add_role'
     permission_required = 'ipam.add_role'
     model = Role
     model = Role
     model_form = forms.RoleForm
     model_form = forms.RoleForm
-
-    def get_return_url(self, request, obj):
-        return reverse('ipam:role_list')
+    default_return_url = 'ipam:role_list'
 
 
 
 
 class RoleEditView(RoleCreateView):
 class RoleEditView(RoleCreateView):
@@ -799,9 +794,7 @@ class VLANGroupCreateView(PermissionRequiredMixin, ObjectEditView):
     permission_required = 'ipam.add_vlangroup'
     permission_required = 'ipam.add_vlangroup'
     model = VLANGroup
     model = VLANGroup
     model_form = forms.VLANGroupForm
     model_form = forms.VLANGroupForm
-
-    def get_return_url(self, request, obj):
-        return reverse('ipam:vlangroup_list')
+    default_return_url = 'ipam:vlangroup_list'
 
 
 
 
 class VLANGroupEditView(VLANGroupCreateView):
 class VLANGroupEditView(VLANGroupCreateView):

+ 2 - 4
netbox/secrets/views.py

@@ -44,9 +44,7 @@ class SecretRoleCreateView(PermissionRequiredMixin, ObjectEditView):
     permission_required = 'secrets.add_secretrole'
     permission_required = 'secrets.add_secretrole'
     model = SecretRole
     model = SecretRole
     model_form = forms.SecretRoleForm
     model_form = forms.SecretRoleForm
-
-    def get_return_url(self, request, obj):
-        return reverse('secrets:secretrole_list')
+    default_return_url = 'secrets:secretrole_list'
 
 
 
 
 class SecretRoleEditView(SecretRoleCreateView):
 class SecretRoleEditView(SecretRoleCreateView):
@@ -244,7 +242,7 @@ class SecretBulkImportView(BulkImportView):
             'form': self._import_form(request.POST),
             'form': self._import_form(request.POST),
             'fields': self.model_form().fields,
             'fields': self.model_form().fields,
             'obj_type': self.model_form._meta.model._meta.verbose_name,
             'obj_type': self.model_form._meta.model._meta.verbose_name,
-            'return_url': self.default_return_url,
+            'return_url': self.get_return_url(request),
         })
         })
 
 
 
 

+ 1 - 1
netbox/templates/import_success.html

@@ -8,6 +8,6 @@
         Import more
         Import more
     </a>
     </a>
     {% if return_url %}
     {% if return_url %}
-        <a href="{% url return_url %}" class="btn btn-default">View All</a>
+        <a href="{{ return_url }}" class="btn btn-default">View All</a>
     {% endif %}
     {% endif %}
 {% endblock %}
 {% endblock %}

+ 1 - 1
netbox/templates/utilities/obj_import.html

@@ -22,7 +22,7 @@
                 <div class="col-md-12 text-right">
                 <div class="col-md-12 text-right">
 		            <button type="submit" class="btn btn-primary">Submit</button>
 		            <button type="submit" class="btn btn-primary">Submit</button>
 		            {% if return_url %}
 		            {% if return_url %}
-                        <a href="{% url return_url %}" class="btn btn-default">Cancel</a>
+                        <a href="{{ return_url }}" class="btn btn-default">Cancel</a>
                     {% endif %}
                     {% endif %}
                 </div>
                 </div>
             </div>
             </div>

+ 1 - 4
netbox/tenancy/views.py

@@ -3,7 +3,6 @@ from __future__ import unicode_literals
 from django.contrib.auth.mixins import PermissionRequiredMixin
 from django.contrib.auth.mixins import PermissionRequiredMixin
 from django.db.models import Count, Q
 from django.db.models import Count, Q
 from django.shortcuts import get_object_or_404, render
 from django.shortcuts import get_object_or_404, render
-from django.urls import reverse
 from django.views.generic import View
 from django.views.generic import View
 
 
 from circuits.models import Circuit
 from circuits.models import Circuit
@@ -31,9 +30,7 @@ class TenantGroupCreateView(PermissionRequiredMixin, ObjectEditView):
     permission_required = 'tenancy.add_tenantgroup'
     permission_required = 'tenancy.add_tenantgroup'
     model = TenantGroup
     model = TenantGroup
     model_form = forms.TenantGroupForm
     model_form = forms.TenantGroupForm
-
-    def get_return_url(self, request, obj):
-        return reverse('tenancy:tenantgroup_list')
+    default_return_url = 'tenancy:tenantgroup_list'
 
 
 
 
 class TenantGroupEditView(TenantGroupCreateView):
 class TenantGroupEditView(TenantGroupCreateView):

+ 36 - 63
netbox/utilities/views.py

@@ -52,14 +52,22 @@ class GetReturnURLMixin(object):
     """
     """
     default_return_url = None
     default_return_url = None
 
 
-    def get_return_url(self, request, obj):
+    def get_return_url(self, request, obj=None):
+
+        # First, see if `return_url` was specified as a query parameter. Use it only if it's considered safe.
         query_param = request.GET.get('return_url')
         query_param = request.GET.get('return_url')
         if query_param and is_safe_url(url=query_param, host=request.get_host()):
         if query_param and is_safe_url(url=query_param, host=request.get_host()):
             return query_param
             return query_param
+
+        # Next, check if the object being modified (if any) has an absolute URL.
         elif obj.pk and hasattr(obj, 'get_absolute_url'):
         elif obj.pk and hasattr(obj, 'get_absolute_url'):
             return obj.get_absolute_url()
             return obj.get_absolute_url()
+
+        # Fall back to the default URL (if specified) for the view.
         elif self.default_return_url is not None:
         elif self.default_return_url is not None:
             return reverse(self.default_return_url)
             return reverse(self.default_return_url)
+
+        # If all else fails, return home. Ideally this should never happen.
         return reverse('home')
         return reverse('home')
 
 
 
 
@@ -159,7 +167,6 @@ class ObjectEditView(GetReturnURLMixin, View):
     model: The model of the object being edited
     model: The model of the object being edited
     model_form: The form used to create or edit the object
     model_form: The form used to create or edit the object
     template_name: The name of the template
     template_name: The name of the template
-    default_return_url: The name of the URL used to display a list of this object type
     """
     """
     model = None
     model = None
     model_form = None
     model_form = None
@@ -236,7 +243,6 @@ class ObjectDeleteView(GetReturnURLMixin, View):
 
 
     model: The model of the object being deleted
     model: The model of the object being deleted
     template_name: The name of the template
     template_name: The name of the template
-    default_return_url: Name of the URL to which the user is redirected after deleting the object
     """
     """
     model = None
     model = None
     template_name = 'utilities/obj_delete.html'
     template_name = 'utilities/obj_delete.html'
@@ -289,20 +295,18 @@ class ObjectDeleteView(GetReturnURLMixin, View):
         })
         })
 
 
 
 
-class BulkCreateView(View):
+class BulkCreateView(GetReturnURLMixin, View):
     """
     """
     Create new objects in bulk.
     Create new objects in bulk.
 
 
     form: Form class which provides the `pattern` field
     form: Form class which provides the `pattern` field
     model_form: The ModelForm used to create individual objects
     model_form: The ModelForm used to create individual objects
     template_name: The name of the template
     template_name: The name of the template
-    default_return_url: Name of the URL to which the user is redirected after creating the objects
     """
     """
     form = None
     form = None
     model_form = None
     model_form = None
     pattern_target = ''
     pattern_target = ''
     template_name = None
     template_name = None
-    default_return_url = 'home'
 
 
     def get(self, request):
     def get(self, request):
 
 
@@ -319,7 +323,7 @@ class BulkCreateView(View):
             'obj_type': self.model_form._meta.model._meta.verbose_name,
             'obj_type': self.model_form._meta.model._meta.verbose_name,
             'form': form,
             'form': form,
             'model_form': model_form,
             'model_form': model_form,
-            'return_url': reverse(self.default_return_url),
+            'return_url': self.get_return_url(request),
         })
         })
 
 
     def post(self, request):
     def post(self, request):
@@ -362,7 +366,7 @@ class BulkCreateView(View):
 
 
                     if '_addanother' in request.POST:
                     if '_addanother' in request.POST:
                         return redirect(request.path)
                         return redirect(request.path)
-                    return redirect(self.default_return_url)
+                    return redirect(self.get_return_url(request))
 
 
             except IntegrityError:
             except IntegrityError:
                 pass
                 pass
@@ -371,23 +375,21 @@ class BulkCreateView(View):
             'form': form,
             'form': form,
             'model_form': model_form,
             'model_form': model_form,
             'obj_type': model._meta.verbose_name,
             'obj_type': model._meta.verbose_name,
-            'return_url': reverse(self.default_return_url),
+            'return_url': self.get_return_url(request),
         })
         })
 
 
 
 
-class BulkImportView(View):
+class BulkImportView(GetReturnURLMixin, View):
     """
     """
     Import objects in bulk (CSV format).
     Import objects in bulk (CSV format).
 
 
     model_form: The form used to create each imported object
     model_form: The form used to create each imported object
     table: The django-tables2 Table used to render the list of imported objects
     table: The django-tables2 Table used to render the list of imported objects
     template_name: The name of the template
     template_name: The name of the template
-    default_return_url: The name of the URL to use for the cancel button
     widget_attrs: A dict of attributes to apply to the import widget (e.g. to require a session key)
     widget_attrs: A dict of attributes to apply to the import widget (e.g. to require a session key)
     """
     """
     model_form = None
     model_form = None
     table = None
     table = None
-    default_return_url = None
     template_name = 'utilities/obj_import.html'
     template_name = 'utilities/obj_import.html'
     widget_attrs = {}
     widget_attrs = {}
 
 
@@ -413,7 +415,7 @@ class BulkImportView(View):
             'form': self._import_form(),
             'form': self._import_form(),
             'fields': self.model_form().fields,
             'fields': self.model_form().fields,
             'obj_type': self.model_form._meta.model._meta.verbose_name,
             'obj_type': self.model_form._meta.model._meta.verbose_name,
-            'return_url': self.default_return_url,
+            'return_url': self.get_return_url(request),
         })
         })
 
 
     def post(self, request):
     def post(self, request):
@@ -446,7 +448,7 @@ class BulkImportView(View):
 
 
                     return render(request, "import_success.html", {
                     return render(request, "import_success.html", {
                         'table': obj_table,
                         'table': obj_table,
-                        'return_url': self.default_return_url,
+                        'return_url': self.get_return_url(request),
                     })
                     })
 
 
             except ValidationError:
             except ValidationError:
@@ -456,11 +458,11 @@ class BulkImportView(View):
             'form': form,
             'form': form,
             'fields': self.model_form().fields,
             'fields': self.model_form().fields,
             'obj_type': self.model_form._meta.model._meta.verbose_name,
             'obj_type': self.model_form._meta.model._meta.verbose_name,
-            'return_url': self.default_return_url,
+            'return_url': self.get_return_url(request),
         })
         })
 
 
 
 
-class BulkEditView(View):
+class BulkEditView(GetReturnURLMixin, View):
     """
     """
     Edit objects in bulk.
     Edit objects in bulk.
 
 
@@ -471,8 +473,6 @@ class BulkEditView(View):
     table: The table used to display devices being edited
     table: The table used to display devices being edited
     form: The form class used to edit objects in bulk
     form: The form class used to edit objects in bulk
     template_name: The name of the template
     template_name: The name of the template
-    default_return_url: Name of the URL to which the user is redirected after editing the objects (can be overridden by
-                        POSTing return_url)
     """
     """
     cls = None
     cls = None
     parent_cls = None
     parent_cls = None
@@ -481,10 +481,9 @@ class BulkEditView(View):
     table = None
     table = None
     form = None
     form = None
     template_name = 'utilities/obj_bulk_edit.html'
     template_name = 'utilities/obj_bulk_edit.html'
-    default_return_url = 'home'
 
 
     def get(self, request):
     def get(self, request):
-        return redirect(self.default_return_url)
+        return redirect(self.get_return_url(request))
 
 
     def post(self, request, **kwargs):
     def post(self, request, **kwargs):
 
 
@@ -494,15 +493,6 @@ class BulkEditView(View):
         else:
         else:
             parent_obj = None
             parent_obj = None
 
 
-        # Determine URL to redirect users upon modification of objects
-        posted_return_url = request.POST.get('return_url')
-        if posted_return_url and is_safe_url(url=posted_return_url, host=request.get_host()):
-            return_url = posted_return_url
-        elif parent_obj:
-            return_url = parent_obj.get_absolute_url()
-        else:
-            return_url = reverse(self.default_return_url)
-
         # Are we editing *all* objects in the queryset or just a selected subset?
         # Are we editing *all* objects in the queryset or just a selected subset?
         if request.POST.get('_all') and self.filter is not None:
         if request.POST.get('_all') and self.filter is not None:
             pk_list = [obj.pk for obj in self.filter(request.GET, self.cls.objects.only('pk')).qs]
             pk_list = [obj.pk for obj in self.filter(request.GET, self.cls.objects.only('pk')).qs]
@@ -559,7 +549,7 @@ class BulkEditView(View):
                         msg = 'Updated {} {}'.format(updated_count, self.cls._meta.verbose_name_plural)
                         msg = 'Updated {} {}'.format(updated_count, self.cls._meta.verbose_name_plural)
                         messages.success(self.request, msg)
                         messages.success(self.request, msg)
 
 
-                    return redirect(return_url)
+                    return redirect(self.get_return_url(request))
 
 
                 except ValidationError as e:
                 except ValidationError as e:
                     messages.error(self.request, "{} failed validation: {}".format(obj, e))
                     messages.error(self.request, "{} failed validation: {}".format(obj, e))
@@ -574,17 +564,17 @@ class BulkEditView(View):
         table = self.table(queryset.filter(pk__in=pk_list), orderable=False)
         table = self.table(queryset.filter(pk__in=pk_list), orderable=False)
         if not table.rows:
         if not table.rows:
             messages.warning(request, "No {} were selected.".format(self.cls._meta.verbose_name_plural))
             messages.warning(request, "No {} were selected.".format(self.cls._meta.verbose_name_plural))
-            return redirect(return_url)
+            return redirect(self.get_return_url(request))
 
 
         return render(request, self.template_name, {
         return render(request, self.template_name, {
             'form': form,
             'form': form,
             'table': table,
             'table': table,
             'obj_type_plural': self.cls._meta.verbose_name_plural,
             'obj_type_plural': self.cls._meta.verbose_name_plural,
-            'return_url': return_url,
+            'return_url': self.get_return_url(request),
         })
         })
 
 
 
 
-class BulkDeleteView(View):
+class BulkDeleteView(GetReturnURLMixin, View):
     """
     """
     Delete objects in bulk.
     Delete objects in bulk.
 
 
@@ -595,8 +585,6 @@ class BulkDeleteView(View):
     table: The table used to display devices being deleted
     table: The table used to display devices being deleted
     form: The form class used to delete objects in bulk
     form: The form class used to delete objects in bulk
     template_name: The name of the template
     template_name: The name of the template
-    default_return_url: Name of the URL to which the user is redirected after deleting the objects (can be overriden by
-                        POSTing return_url)
     """
     """
     cls = None
     cls = None
     parent_cls = None
     parent_cls = None
@@ -605,10 +593,9 @@ class BulkDeleteView(View):
     table = None
     table = None
     form = None
     form = None
     template_name = 'utilities/obj_bulk_delete.html'
     template_name = 'utilities/obj_bulk_delete.html'
-    default_return_url = 'home'
 
 
     def get(self, request):
     def get(self, request):
-        return redirect(self.default_return_url)
+        return redirect(self.get_return_url(request))
 
 
     def post(self, request, **kwargs):
     def post(self, request, **kwargs):
 
 
@@ -618,15 +605,6 @@ class BulkDeleteView(View):
         else:
         else:
             parent_obj = None
             parent_obj = None
 
 
-        # Determine URL to redirect users upon deletion of objects
-        posted_return_url = request.POST.get('return_url')
-        if posted_return_url and is_safe_url(url=posted_return_url, host=request.get_host()):
-            return_url = posted_return_url
-        elif parent_obj:
-            return_url = parent_obj.get_absolute_url()
-        else:
-            return_url = reverse(self.default_return_url)
-
         # Are we deleting *all* objects in the queryset or just a selected subset?
         # Are we deleting *all* objects in the queryset or just a selected subset?
         if request.POST.get('_all'):
         if request.POST.get('_all'):
             if self.filter is not None:
             if self.filter is not None:
@@ -648,28 +626,31 @@ class BulkDeleteView(View):
                     deleted_count = queryset.delete()[1][self.cls._meta.label]
                     deleted_count = queryset.delete()[1][self.cls._meta.label]
                 except ProtectedError as e:
                 except ProtectedError as e:
                     handle_protectederror(list(queryset), request, e)
                     handle_protectederror(list(queryset), request, e)
-                    return redirect(return_url)
+                    return redirect(self.get_return_url(request))
 
 
                 msg = 'Deleted {} {}'.format(deleted_count, self.cls._meta.verbose_name_plural)
                 msg = 'Deleted {} {}'.format(deleted_count, self.cls._meta.verbose_name_plural)
                 messages.success(request, msg)
                 messages.success(request, msg)
-                return redirect(return_url)
+                return redirect(self.get_return_url(request))
 
 
         else:
         else:
-            form = form_cls(initial={'pk': pk_list, 'return_url': return_url})
+            form = form_cls(initial={
+                'pk': pk_list,
+                'return_url': self.get_return_url(request),
+            })
 
 
         # Retrieve objects being deleted
         # Retrieve objects being deleted
         queryset = self.queryset or self.cls.objects.all()
         queryset = self.queryset or self.cls.objects.all()
         table = self.table(queryset.filter(pk__in=pk_list), orderable=False)
         table = self.table(queryset.filter(pk__in=pk_list), orderable=False)
         if not table.rows:
         if not table.rows:
             messages.warning(request, "No {} were selected for deletion.".format(self.cls._meta.verbose_name_plural))
             messages.warning(request, "No {} were selected for deletion.".format(self.cls._meta.verbose_name_plural))
-            return redirect(return_url)
+            return redirect(self.get_return_url(request))
 
 
         return render(request, self.template_name, {
         return render(request, self.template_name, {
             'form': form,
             'form': form,
             'parent_obj': parent_obj,
             'parent_obj': parent_obj,
             'obj_type_plural': self.cls._meta.verbose_name_plural,
             'obj_type_plural': self.cls._meta.verbose_name_plural,
             'table': table,
             'table': table,
-            'return_url': return_url,
+            'return_url': self.get_return_url(request),
         })
         })
 
 
     def get_form(self):
     def get_form(self):
@@ -785,7 +766,7 @@ class ComponentCreateView(View):
         })
         })
 
 
 
 
-class BulkComponentCreateView(View):
+class BulkComponentCreateView(GetReturnURLMixin, View):
     """
     """
     Add one or more components (e.g. interfaces, console ports, etc.) to a set of Devices or VirtualMachines.
     Add one or more components (e.g. interfaces, console ports, etc.) to a set of Devices or VirtualMachines.
     """
     """
@@ -797,7 +778,6 @@ class BulkComponentCreateView(View):
     filter = None
     filter = None
     table = None
     table = None
     template_name = 'utilities/obj_bulk_add_component.html'
     template_name = 'utilities/obj_bulk_add_component.html'
-    default_return_url = 'home'
 
 
     def post(self, request):
     def post(self, request):
 
 
@@ -807,17 +787,10 @@ class BulkComponentCreateView(View):
         else:
         else:
             pk_list = [int(pk) for pk in request.POST.getlist('pk')]
             pk_list = [int(pk) for pk in request.POST.getlist('pk')]
 
 
-        # Determine URL to redirect users upon modification of objects
-        posted_return_url = request.POST.get('return_url')
-        if posted_return_url and is_safe_url(url=posted_return_url, host=request.get_host()):
-            return_url = posted_return_url
-        else:
-            return_url = reverse(self.default_return_url)
-
         selected_objects = self.parent_model.objects.filter(pk__in=pk_list)
         selected_objects = self.parent_model.objects.filter(pk__in=pk_list)
         if not selected_objects:
         if not selected_objects:
             messages.warning(request, "No {} were selected.".format(self.parent_model._meta.verbose_name_plural))
             messages.warning(request, "No {} were selected.".format(self.parent_model._meta.verbose_name_plural))
-            return redirect(return_url)
+            return redirect(self.get_return_url(request))
         table = self.table(selected_objects)
         table = self.table(selected_objects)
 
 
         if '_create' in request.POST:
         if '_create' in request.POST:
@@ -855,7 +828,7 @@ class BulkComponentCreateView(View):
                         len(form.cleaned_data['pk']),
                         len(form.cleaned_data['pk']),
                         self.parent_model._meta.verbose_name_plural
                         self.parent_model._meta.verbose_name_plural
                     ))
                     ))
-                    return redirect(return_url)
+                    return redirect(self.get_return_url(request))
 
 
         else:
         else:
             form = self.form(initial={'pk': pk_list})
             form = self.form(initial={'pk': pk_list})
@@ -864,5 +837,5 @@ class BulkComponentCreateView(View):
             'form': form,
             'form': form,
             'component_name': self.model._meta.verbose_name_plural,
             'component_name': self.model._meta.verbose_name_plural,
             'table': table,
             'table': table,
-            'return_url': reverse(self.default_return_url),
+            'return_url': self.get_return_url(request),
         })
         })

+ 2 - 6
netbox/virtualization/views.py

@@ -33,9 +33,7 @@ class ClusterTypeCreateView(PermissionRequiredMixin, ObjectEditView):
     permission_required = 'virtualization.add_clustertype'
     permission_required = 'virtualization.add_clustertype'
     model = ClusterType
     model = ClusterType
     model_form = forms.ClusterTypeForm
     model_form = forms.ClusterTypeForm
-
-    def get_return_url(self, request, obj):
-        return reverse('virtualization:clustertype_list')
+    default_return_url = 'virtualization:clustertype_list'
 
 
 
 
 class ClusterTypeEditView(ClusterTypeCreateView):
 class ClusterTypeEditView(ClusterTypeCreateView):
@@ -71,9 +69,7 @@ class ClusterGroupCreateView(PermissionRequiredMixin, ObjectEditView):
     permission_required = 'virtualization.add_clustergroup'
     permission_required = 'virtualization.add_clustergroup'
     model = ClusterGroup
     model = ClusterGroup
     model_form = forms.ClusterGroupForm
     model_form = forms.ClusterGroupForm
-
-    def get_return_url(self, request, obj):
-        return reverse('virtualization:clustergroup_list')
+    default_return_url = 'virtualization:clustergroup_list'
 
 
 
 
 class ClusterGroupEditView(ClusterGroupCreateView):
 class ClusterGroupEditView(ClusterGroupCreateView):