Просмотр исходного кода

changed name of csv_file variable and started work on ValidationError

Alyssa Bigley 4 лет назад
Родитель
Сommit
55b7cf21cc
2 измененных файлов с 31 добавлено и 21 удалено
  1. 9 3
      netbox/netbox/views/generic.py
  2. 22 18
      netbox/utilities/forms/fields.py

+ 9 - 3
netbox/netbox/views/generic.py

@@ -665,10 +665,16 @@ class BulkImportView(GetReturnURLMixin, ObjectPermissionRequiredMixin, View):
                 from_form=self.model_form,
                 widget=Textarea(attrs=self.widget_attrs)
             )
-            upload_csv = CSVFileField(
+            csv_file = CSVFileField(
+                label="CSV file",
                 from_form=self.model_form,
                 required=False
             )
+            def used_both_methods(self):
+                if self.cleaned_data['csv_file'][1] and self.cleaned_data['csv'][1]:
+                    raise ValidationError('')
+                return False
+
         return ImportForm(*args, **kwargs)
 
     def _save_obj(self, obj_form, request):
@@ -694,14 +700,14 @@ class BulkImportView(GetReturnURLMixin, ObjectPermissionRequiredMixin, View):
         new_objs = []
         form = self._import_form(request.POST, request.FILES)
 
-        if form.is_valid():
+        if form.is_valid() and not form.used_both_methods():
             logger.debug("Form validation was successful")
 
             try:
                 # Iterate through CSV data and bind each row to a new model form instance.
                 with transaction.atomic():
                     if request.FILES:
-                        headers, records = form.cleaned_data['upload_csv']
+                        headers, records = form.cleaned_data['csv_file']
                     else:
                         headers, records = form.cleaned_data['csv']
                     for row, data in enumerate(records, start=1):

+ 22 - 18
netbox/utilities/forms/fields.py

@@ -246,35 +246,39 @@ class CSVFileField(forms.FileField):
     def to_python(self, file):
 
         records = []
-        csv_str = file.read().decode('utf-8')
-        reader = csv.reader(csv_str.splitlines())
+        if file:
+            csv_str = file.read().decode('utf-8')
+            reader = csv.reader(csv_str.splitlines())
 
         # Consume the first line of CSV data as column headers. Create a dictionary mapping each header to an optional
         # "to" field specifying how the related object is being referenced. For example, importing a Device might use a
         # `site.slug` header, to indicate the related site is being referenced by its slug.
 
         headers = {}
-        for header in next(reader):
-            if '.' in header:
-                field, to_field = header.split('.', 1)
-                headers[field] = to_field
-            else:
-                headers[header] = None
-
-        # Parse CSV rows into a list of dictionaries mapped from the column headers.
-        for i, row in enumerate(reader, start=1):
-            if len(row) != len(headers):
-                raise forms.ValidationError(
-                    f"Row {i}: Expected {len(headers)} columns but found {len(row)}"
-                )
-            row = [col.strip() for col in row]
-            record = dict(zip(headers.keys(), row))
-            records.append(record)
+        if file:
+            for header in next(reader):
+                if '.' in header:
+                    field, to_field = header.split('.', 1)
+                    headers[field] = to_field
+                else:
+                    headers[header] = None
+
+            # Parse CSV rows into a list of dictionaries mapped from the column headers.
+            for i, row in enumerate(reader, start=1):
+                if len(row) != len(headers):
+                    raise forms.ValidationError(
+                        f"Row {i}: Expected {len(headers)} columns but found {len(row)}"
+                    )
+                row = [col.strip() for col in row]
+                record = dict(zip(headers.keys(), row))
+                records.append(record)
 
         return headers, records
 
     def validate(self, value):
         headers, records = value
+        if not headers and not records:
+            return value
 
         # Validate provided column headers
         for field, to_field in headers.items():