Procházet zdrojové kódy

Merge pull request #6561 from abigley/csv_feature

CSV file import
Jeremy Stretch před 4 roky
rodič
revize
ea0de4b01d

+ 17 - 3
netbox/netbox/views/generic.py

@@ -20,7 +20,7 @@ from extras.models import CustomField, ExportTemplate
 from utilities.error_handlers import handle_protectederror
 from utilities.exceptions import AbortTransaction, PermissionsViolation
 from utilities.forms import (
-    BootstrapMixin, BulkRenameForm, ConfirmationForm, CSVDataField, ImportForm, TableConfigForm, restrict_form_fields,
+    BootstrapMixin, BulkRenameForm, ConfirmationForm, CSVDataField, ImportForm, TableConfigForm, restrict_form_fields, CSVFileField
 )
 from utilities.permissions import get_permission_for_model
 from utilities.tables import paginate_table
@@ -667,6 +667,14 @@ class BulkImportView(GetReturnURLMixin, ObjectPermissionRequiredMixin, View):
                 from_form=self.model_form,
                 widget=Textarea(attrs=self.widget_attrs)
             )
+            csv_file = CSVFileField(
+                label="CSV file",
+                from_form=self.model_form,
+                required=False
+            )
+
+            def used_both_csv_fields(self):
+                return self.cleaned_data['csv_file'][1] and self.cleaned_data['csv'][1]
 
         return ImportForm(*args, **kwargs)
 
@@ -691,15 +699,21 @@ class BulkImportView(GetReturnURLMixin, ObjectPermissionRequiredMixin, View):
     def post(self, request):
         logger = logging.getLogger('netbox.views.BulkImportView')
         new_objs = []
-        form = self._import_form(request.POST)
+        form = self._import_form(request.POST, request.FILES)
 
         if form.is_valid():
             logger.debug("Form validation was successful")
 
             try:
+                if form.used_both_csv_fields():
+                    form.add_error('csv_file', "Choose one of two import methods")
+                    raise ValidationError("")
                 # Iterate through CSV data and bind each row to a new model form instance.
                 with transaction.atomic():
-                    headers, records = form.cleaned_data['csv']
+                    if request.FILES:
+                        headers, records = form.cleaned_data['csv_file']
+                    else:
+                        headers, records = form.cleaned_data['csv']
                     for row, data in enumerate(records, start=1):
                         obj_form = self.model_form(data, headers=headers)
                         restrict_form_fields(obj_form, request.user)

+ 1 - 1
netbox/templates/generic/object_bulk_import.html

@@ -20,7 +20,7 @@
             </ul>
             <div class="tab-content">
                 <div role="tabpanel" class="tab-pane active" id="csv">
-                    <form action="" method="post" class="form">
+                    <form action="" method="post" class="form" enctype="multipart/form-data">
                         {% csrf_token %}
                         {% render_form form %}
                         <div class="form-group">

+ 43 - 35
netbox/utilities/forms/fields.py

@@ -17,7 +17,7 @@ from utilities.utils import content_type_name
 from utilities.validators import EnhancedURLValidator
 from . import widgets
 from .constants import *
-from .utils import expand_alphanumeric_pattern, expand_ipaddress_pattern
+from .utils import expand_alphanumeric_pattern, expand_ipaddress_pattern, parse_csv, validate_csv
 
 __all__ = (
     'CommentField',
@@ -26,6 +26,7 @@ __all__ = (
     'CSVChoiceField',
     'CSVContentTypeField',
     'CSVDataField',
+    'CSVFileField',
     'CSVModelChoiceField',
     'CSVTypedChoiceField',
     'DynamicModelChoiceField',
@@ -174,49 +175,56 @@ class CSVDataField(forms.CharField):
                              'in double quotes.'
 
     def to_python(self, value):
-
-        records = []
         reader = csv.reader(StringIO(value.strip()))
 
-        # Consume the first line of CSV data as column headers. Create a dictionary mapping each header to an optional
-        # "to" field specifying how the related object is being referenced. For example, importing a Device might use a
-        # `site.slug` header, to indicate the related site is being referenced by its slug.
+        return parse_csv(reader)
+
+    def validate(self, value):
+        headers, records = value
+        validate_csv(headers, self.fields, self.required_fields)
+
+        return value
+
+
+class CSVFileField(forms.FileField):
+    """
+    A FileField (rendered as a file input button) which accepts a file containing CSV-formatted data. It returns
+    data as a two-tuple: The first item is a dictionary of column headers, mapping field names to the attribute
+    by which they match a related object (where applicable). The second item is a list of dictionaries, each
+    representing a discrete row of CSV data.
+
+    :param from_form: The form from which the field derives its validation rules.
+    """
+
+    def __init__(self, from_form, *args, **kwargs):
+
+        form = from_form()
+        self.model = form.Meta.model
+        self.fields = form.fields
+        self.required_fields = [
+            name for name, field in form.fields.items() if field.required
+        ]
+
+        super().__init__(*args, **kwargs)
+
+    def to_python(self, file):
+        if file:
+            csv_str = file.read().decode('utf-8')
+            reader = csv.reader(csv_str.splitlines())
+
         headers = {}
-        for header in next(reader):
-            if '.' in header:
-                field, to_field = header.split('.', 1)
-                headers[field] = to_field
-            else:
-                headers[header] = None
-
-        # Parse CSV rows into a list of dictionaries mapped from the column headers.
-        for i, row in enumerate(reader, start=1):
-            if len(row) != len(headers):
-                raise forms.ValidationError(
-                    f"Row {i}: Expected {len(headers)} columns but found {len(row)}"
-                )
-            row = [col.strip() for col in row]
-            record = dict(zip(headers.keys(), row))
-            records.append(record)
+        records = []
+        if file:
+            headers, records = parse_csv(reader)
 
         return headers, records
 
     def validate(self, value):
         headers, records = value
+        if not headers and not records:
+            return value
 
-        # Validate provided column headers
-        for field, to_field in headers.items():
-            if field not in self.fields:
-                raise forms.ValidationError(f'Unexpected column header "{field}" found.')
-            if to_field and not hasattr(self.fields[field], 'to_field_name'):
-                raise forms.ValidationError(f'Column "{field}" is not a related object; cannot use dots')
-            if to_field and not hasattr(self.fields[field].queryset.model, to_field):
-                raise forms.ValidationError(f'Invalid related object attribute for column "{field}": {to_field}')
-
-        # Validate required fields
-        for f in self.required_fields:
-            if f not in headers:
-                raise forms.ValidationError(f'Required column header "{f}" not found.')
+        validate_csv(headers, self.fields, self.required_fields)
 
         return value
 

+ 53 - 0
netbox/utilities/forms/utils.py

@@ -14,6 +14,8 @@ __all__ = (
     'parse_alphanumeric_range',
     'parse_numeric_range',
     'restrict_form_fields',
+    'parse_csv',
+    'validate_csv',
 )
 
 
@@ -134,3 +136,54 @@ def restrict_form_fields(form, user, action='view'):
     for field in form.fields.values():
         if hasattr(field, 'queryset') and issubclass(field.queryset.__class__, RestrictedQuerySet):
             field.queryset = field.queryset.restrict(user, action)
+
+
+def parse_csv(reader):
+    """
+    Parse a csv_reader object into a headers dictionary and a list of records dictionaries. Raise an error
+    if the records are formatted incorrectly. Return headers and records as a tuple.
+    """
+    records = []
+    headers = {}
+
+    # Consume the first line of CSV data as column headers. Create a dictionary mapping each header to an optional
+    # "to" field specifying how the related object is being referenced. For example, importing a Device might use a
+    # `site.slug` header, to indicate the related site is being referenced by its slug.
+
+    for header in next(reader):
+        if '.' in header:
+            field, to_field = header.split('.', 1)
+            headers[field] = to_field
+        else:
+            headers[header] = None
+
+    # Parse CSV rows into a list of dictionaries mapped from the column headers.
+    for i, row in enumerate(reader, start=1):
+        if len(row) != len(headers):
+            raise forms.ValidationError(
+                f"Row {i}: Expected {len(headers)} columns but found {len(row)}"
+            )
+        row = [col.strip() for col in row]
+        record = dict(zip(headers.keys(), row))
+        records.append(record)
+    return headers, records
+
+
+def validate_csv(headers, fields, required_fields):
+    """
+    Validate that parsed csv data conforms to the object's available fields. Raise validation errors
+    if parsed csv data contains invalid headers or does not contain required headers.
+    """
+    # Validate provided column headers
+    for field, to_field in headers.items():
+        if field not in fields:
+            raise forms.ValidationError(f'Unexpected column header "{field}" found.')
+        if to_field and not hasattr(fields[field], 'to_field_name'):
+            raise forms.ValidationError(f'Column "{field}" is not a related object; cannot use dots')
+        if to_field and not hasattr(fields[field].queryset.model, to_field):
+            raise forms.ValidationError(f'Invalid related object attribute for column "{field}": {to_field}')
+
+    # Validate required fields
+    for f in required_fields:
+        if f not in headers:
+            raise forms.ValidationError(f'Required column header "{f}" not found.')