Просмотр исходного кода

CSV import implemented using CSVFileField

Alyssa Bigley 4 лет назад
Родитель
Сommit
c2b2b059e6
2 измененных файлов с 77 добавлено и 22 удалено
  1. 5 22
      netbox/netbox/views/generic.py
  2. 72 0
      netbox/utilities/forms/fields.py

+ 5 - 22
netbox/netbox/views/generic.py

@@ -1,6 +1,5 @@
 import logging
 import re
-import csv
 from copy import deepcopy
 
 from django.contrib import messages
@@ -8,7 +7,7 @@ from django.contrib.contenttypes.models import ContentType
 from django.core.exceptions import FieldDoesNotExist, ObjectDoesNotExist, ValidationError
 from django.db import transaction, IntegrityError
 from django.db.models import ManyToManyField, ProtectedError
-from django.forms import Form, ModelMultipleChoiceField, MultipleHiddenInput, Textarea, FileField
+from django.forms import Form, ModelMultipleChoiceField, MultipleHiddenInput, Textarea
 from django.http import HttpResponse
 from django.shortcuts import get_object_or_404, redirect, render
 from django.utils.html import escape
@@ -21,7 +20,7 @@ from extras.models import CustomField, ExportTemplate
 from utilities.error_handlers import handle_protectederror
 from utilities.exceptions import AbortTransaction
 from utilities.forms import (
-    BootstrapMixin, BulkRenameForm, ConfirmationForm, CSVDataField, ImportForm, TableConfigForm, restrict_form_fields,
+    BootstrapMixin, BulkRenameForm, ConfirmationForm, CSVDataField, ImportForm, TableConfigForm, restrict_form_fields, CSVFileField
 )
 from utilities.permissions import get_permission_for_model
 from utilities.tables import paginate_table
@@ -666,7 +665,8 @@ class BulkImportView(GetReturnURLMixin, ObjectPermissionRequiredMixin, View):
                 from_form=self.model_form,
                 widget=Textarea(attrs=self.widget_attrs)
             )
-            upload_csv = FileField(
+            upload_csv = CSVFileField(
+                from_form=self.model_form,
                 required=False
             )
         return ImportForm(*args, **kwargs)
@@ -701,26 +701,9 @@ class BulkImportView(GetReturnURLMixin, ObjectPermissionRequiredMixin, View):
                 # Iterate through CSV data and bind each row to a new model form instance.
                 with transaction.atomic():
                     if request.FILES:
-                        csv_file = request.FILES["upload_csv"]
-                        csv_file.seek(0)
-                        csv_str = csv_file.read().decode('utf-8')
-                        reader = csv.reader(csv_str.splitlines())
-                        headers_list = next(reader)
-                        headers = {}
-                        for header in headers_list:
-                            headers[header] = None
-                        records = []
-                        for row in reader:
-                            row_dict = {}
-                            for i, elt in enumerate(row):
-                                if elt == '':
-                                    row_dict[headers_list[i]] = None
-                                else:
-                                    row_dict[headers_list[i]] = elt
-                            records.append(row_dict)
+                        headers, records = form.cleaned_data['upload_csv']
                     else:
                         headers, records = form.cleaned_data['csv']
-                    print("headers:", headers, "records:", records)
                     for row, data in enumerate(records, start=1):
                         obj_form = self.model_form(data, headers=headers)
                         restrict_form_fields(obj_form, request.user)

+ 72 - 0
netbox/utilities/forms/fields.py

@@ -26,6 +26,7 @@ __all__ = (
     'CSVChoiceField',
     'CSVContentTypeField',
     'CSVDataField',
+    'CSVFileField',
     'CSVModelChoiceField',
     'CSVTypedChoiceField',
     'DynamicModelChoiceField',
@@ -221,6 +222,77 @@ class CSVDataField(forms.CharField):
         return value
 
 
+class CSVFileField(forms.FileField):
+    """
+    A CharField (rendered as a Textarea) which accepts CSV-formatted data. It returns data as a two-tuple: The first
+    item is a dictionary of column headers, mapping field names to the attribute by which they match a related object
+    (where applicable). The second item is a list of dictionaries, each representing a discrete row of CSV data.
+
+    :param from_form: The form from which the field derives its validation rules.
+    """
+
+    def __init__(self, from_form, *args, **kwargs):
+
+        form = from_form()
+        self.model = form.Meta.model
+        self.fields = form.fields
+        self.required_fields = [
+            name for name, field in form.fields.items() if field.required
+        ]
+
+        super().__init__(*args, **kwargs)
+
+    def to_python(self, file):
+
+        records = []
+        file.seek(0)
+        csv_str = file.read().decode('utf-8')
+        reader = csv.reader(csv_str.splitlines())
+
+        # Consume the first line of CSV data as column headers. Create a dictionary mapping each header to an optional
+        # "to" field specifying how the related object is being referenced. For example, importing a Device might use a
+        # `site.slug` header, to indicate the related site is being referenced by its slug.
+
+        headers = {}
+        for header in next(reader):
+            if '.' in header:
+                field, to_field = header.split('.', 1)
+                headers[field] = to_field
+            else:
+                headers[header] = None
+
+        # Parse CSV rows into a list of dictionaries mapped from the column headers.
+        for i, row in enumerate(reader, start=1):
+            if len(row) != len(headers):
+                raise forms.ValidationError(
+                    f"Row {i}: Expected {len(headers)} columns but found {len(row)}"
+                )
+            row = [col.strip() for col in row]
+            record = dict(zip(headers.keys(), row))
+            records.append(record)
+
+        return headers, records
+
+    def validate(self, value):
+        headers, records = value
+
+        # Validate provided column headers
+        for field, to_field in headers.items():
+            if field not in self.fields:
+                raise forms.ValidationError(f'Unexpected column header "{field}" found.')
+            if to_field and not hasattr(self.fields[field], 'to_field_name'):
+                raise forms.ValidationError(f'Column "{field}" is not a related object; cannot use dots')
+            if to_field and not hasattr(self.fields[field].queryset.model, to_field):
+                raise forms.ValidationError(f'Invalid related object attribute for column "{field}": {to_field}')
+
+        # Validate required fields
+        for f in self.required_fields:
+            if f not in headers:
+                raise forms.ValidationError(f'Required column header "{f}" not found.')
+
+        return value
+
+
 class CSVChoiceField(forms.ChoiceField):
     """
     Invert the provided set of choices to take the human-friendly label as input, and return the database value.