Просмотр исходного кода

Enable the specifcation of related objects by arbitrary attribute during CSV import

Jeremy Stretch 5 лет назад
Родитель
Сommit
34a17d4571
2 измененных файлов с 53 добавлено и 20 удалено
  1. 35 16
      netbox/utilities/forms.py
  2. 18 4
      netbox/utilities/views.py

+ 35 - 16
netbox/utilities/forms.py

@@ -405,10 +405,11 @@ class CSVDataField(forms.CharField):
     """
     widget = forms.Textarea
 
-    def __init__(self, fields, required_fields=[], *args, **kwargs):
+    def __init__(self, model, fields, required_fields=None, *args, **kwargs):
 
+        self.model = model
         self.fields = fields
-        self.required_fields = required_fields
+        self.required_fields = required_fields or list()
 
         super().__init__(*args, **kwargs)
 
@@ -423,31 +424,49 @@ class CSVDataField(forms.CharField):
                              'in double quotes.'
 
     def to_python(self, value):
-
         records = []
         reader = csv.reader(StringIO(value))
 
-        # Consume and validate the first line of CSV data as column headers
-        headers = next(reader)
-        for f in self.required_fields:
-            if f not in headers:
-                raise forms.ValidationError('Required column header "{}" not found.'.format(f))
-        for f in headers:
-            if f not in self.fields:
-                raise forms.ValidationError('Unexpected column header "{}" found.'.format(f))
+        # Consume the first line of CSV data as column headers. Create a dictionary mapping each header to an optional
+        # "to" field specifying how the related object is being referenced. For example, importing a Device might use a
+        # `site.slug` header, to indicate the related site is being referenced by its slug.
+        headers = {}
+        for header in next(reader):
+            if '.' in header:
+                field, to_field = header.split('.', 1)
+                headers[field] = to_field
+            else:
+                headers[header] = None
 
         # Parse CSV data
         for i, row in enumerate(reader, start=1):
             if row:
                 if len(row) != len(headers):
-                    raise forms.ValidationError(
-                        "Row {}: Expected {} columns but found {}".format(i, len(headers), len(row))
-                    )
+                    raise forms.ValidationError(f"Row {i}: Expected {len(headers)} columns but found {len(row)}")
                 row = [col.strip() for col in row]
-                record = dict(zip(headers, row))
+                record = dict(zip(headers.keys(), row))
                 records.append(record)
 
-        return records
+        return headers, records
+
+    def validate(self, value):
+        headers, records = value
+
+        # Validate provided column headers
+        for field, to_field in headers.items():
+            if field not in self.fields:
+                raise forms.ValidationError(f'Unexpected column header "{field}" found.')
+            if to_field and not hasattr(self.fields[field], 'to_field_name'):
+                raise forms.ValidationError(f'Column "{field}" is not a related object; cannot use dots')
+            if to_field and not hasattr(self.fields[field].queryset.model, to_field):
+                raise forms.ValidationError(f'Invalid related object attribute for column "{field}": {to_field}')
+
+        # Validate required fields
+        for f in self.required_fields:
+            if f not in headers:
+                raise forms.ValidationError(f'Required column header "{f}" not found.')
+
+        return value
 
 
 class CSVChoiceField(forms.ChoiceField):

+ 18 - 4
netbox/utilities/views.py

@@ -557,11 +557,18 @@ class BulkImportView(GetReturnURLMixin, View):
 
     def _import_form(self, *args, **kwargs):
 
-        fields = self.model_form().fields.keys()
-        required_fields = [name for name, field in self.model_form().fields.items() if field.required]
+        fields = self.model_form().fields
+        required_fields = [
+            name for name, field in self.model_form().fields.items() if field.required
+        ]
 
         class ImportForm(BootstrapMixin, Form):
-            csv = CSVDataField(fields=fields, required_fields=required_fields, widget=Textarea(attrs=self.widget_attrs))
+            csv = CSVDataField(
+                model=self.model_form.Meta.model,
+                fields=fields,
+                required_fields=required_fields,
+                widget=Textarea(attrs=self.widget_attrs)
+            )
 
         return ImportForm(*args, **kwargs)
 
@@ -591,8 +598,15 @@ class BulkImportView(GetReturnURLMixin, View):
             try:
                 # Iterate through CSV data and bind each row to a new model form instance.
                 with transaction.atomic():
-                    for row, data in enumerate(form.cleaned_data['csv'], start=1):
+                    headers, records = form.cleaned_data['csv']
+                    for row, data in enumerate(records, start=1):
                         obj_form = self.model_form(data)
+
+                        # Modify the model form to accommodate any customized to_field_name properties
+                        for field, to_field in headers.items():
+                            if to_field is not None:
+                                obj_form.fields[field].to_field_name = to_field
+
                         if obj_form.is_valid():
                             obj = self._save_obj(obj_form, request)
                             new_objs.append(obj)