Jeremy Stretch 5 лет назад
Родитель
Сommit
61ae4be16a
3 измененных файлов с 100 добавлено и 20 удалено
  1. 15 12
      netbox/utilities/forms.py
  2. 84 0
      netbox/utilities/tests/test_forms.py
  3. 1 8
      netbox/utilities/views.py

+ 15 - 12
netbox/utilities/forms.py

@@ -405,11 +405,14 @@ class CSVDataField(forms.CharField):
     """
     widget = forms.Textarea
 
-    def __init__(self, model, fields, required_fields=None, *args, **kwargs):
+    def __init__(self, from_form, *args, **kwargs):
 
-        self.model = model
-        self.fields = fields
-        self.required_fields = required_fields or list()
+        form = from_form()
+        self.model = form.Meta.model
+        self.fields = form.fields
+        self.required_fields = [
+            name for name, field in form.fields.items() if field.required
+        ]
 
         super().__init__(*args, **kwargs)
 
@@ -417,15 +420,16 @@ class CSVDataField(forms.CharField):
         if not self.label:
             self.label = ''
         if not self.initial:
-            self.initial = ','.join(required_fields) + '\n'
+            self.initial = ','.join(self.required_fields) + '\n'
         if not self.help_text:
             self.help_text = 'Enter the list of column headers followed by one line per record to be imported, using ' \
                              'commas to separate values. Multi-line data and values containing commas may be wrapped ' \
                              'in double quotes.'
 
     def to_python(self, value):
+
         records = []
-        reader = csv.reader(StringIO(value))
+        reader = csv.reader(StringIO(value.strip()))
 
         # Consume the first line of CSV data as column headers. Create a dictionary mapping each header to an optional
         # "to" field specifying how the related object is being referenced. For example, importing a Device might use a
@@ -440,12 +444,11 @@ class CSVDataField(forms.CharField):
 
         # Parse CSV data
         for i, row in enumerate(reader, start=1):
-            if row:
-                if len(row) != len(headers):
-                    raise forms.ValidationError(f"Row {i}: Expected {len(headers)} columns but found {len(row)}")
-                row = [col.strip() for col in row]
-                record = dict(zip(headers.keys(), row))
-                records.append(record)
+            if len(row) != len(headers):
+                raise forms.ValidationError(f"Row {i}: Expected {len(headers)} columns but found {len(row)}")
+            row = [col.strip() for col in row]
+            record = dict(zip(headers.keys(), row))
+            records.append(record)
 
         return headers, records
 

+ 84 - 0
netbox/utilities/tests/test_forms.py

@@ -1,6 +1,8 @@
 from django import forms
 from django.test import TestCase
 
+from ipam.forms import IPAddressCSVForm
+from ipam.models import VRF
 from utilities.forms import *
 
 
@@ -281,3 +283,85 @@ class ExpandAlphanumeric(TestCase):
 
         with self.assertRaises(ValueError):
             sorted(expand_alphanumeric_pattern('r[a,,b]a'))
+
+
+class CSVDataFieldTest(TestCase):
+
+    def setUp(self):
+        self.field = CSVDataField(from_form=IPAddressCSVForm)
+
+    def test_clean(self):
+        input = """
+        address,status,vrf
+        192.0.2.1/32,Active,Test VRF
+        """
+        output = (
+            {'address': None, 'status': None, 'vrf': None},
+            [{'address': '192.0.2.1/32', 'status': 'Active', 'vrf': 'Test VRF'}]
+        )
+        self.assertEqual(self.field.clean(input), output)
+
+    def test_clean_invalid_header(self):
+        input = """
+        address,status,vrf,xxx
+        192.0.2.1/32,Active,Test VRF,123
+        """
+        with self.assertRaises(forms.ValidationError):
+            self.field.clean(input)
+
+    def test_clean_missing_required_header(self):
+        input = """
+        status,vrf
+        Active,Test VRF
+        """
+        with self.assertRaises(forms.ValidationError):
+            self.field.clean(input)
+
+    def test_clean_default_to_field(self):
+        input = """
+        address,status,vrf.name
+        192.0.2.1/32,Active,Test VRF
+        """
+        output = (
+            {'address': None, 'status': None, 'vrf': 'name'},
+            [{'address': '192.0.2.1/32', 'status': 'Active', 'vrf': 'Test VRF'}]
+        )
+        self.assertEqual(self.field.clean(input), output)
+
+    def test_clean_pk_to_field(self):
+        input = """
+        address,status,vrf.pk
+        192.0.2.1/32,Active,123
+        """
+        output = (
+            {'address': None, 'status': None, 'vrf': 'pk'},
+            [{'address': '192.0.2.1/32', 'status': 'Active', 'vrf': '123'}]
+        )
+        self.assertEqual(self.field.clean(input), output)
+
+    def test_clean_custom_to_field(self):
+        input = """
+        address,status,vrf.rd
+        192.0.2.1/32,Active,123:456
+        """
+        output = (
+            {'address': None, 'status': None, 'vrf': 'rd'},
+            [{'address': '192.0.2.1/32', 'status': 'Active', 'vrf': '123:456'}]
+        )
+        self.assertEqual(self.field.clean(input), output)
+
+    def test_clean_invalid_to_field(self):
+        input = """
+        address,status,vrf.xxx
+        192.0.2.1/32,Active,123:456
+        """
+        with self.assertRaises(forms.ValidationError):
+            self.field.clean(input)
+
+    def test_clean_to_field_on_non_object(self):
+        input = """
+        address,status.foo,vrf
+        192.0.2.1/32,Bar,Test VRF
+        """
+        with self.assertRaises(forms.ValidationError):
+            self.field.clean(input)

+ 1 - 8
netbox/utilities/views.py

@@ -557,16 +557,9 @@ class BulkImportView(GetReturnURLMixin, View):
 
     def _import_form(self, *args, **kwargs):
 
-        fields = self.model_form().fields
-        required_fields = [
-            name for name, field in self.model_form().fields.items() if field.required
-        ]
-
         class ImportForm(BootstrapMixin, Form):
             csv = CSVDataField(
-                model=self.model_form.Meta.model,
-                fields=fields,
-                required_fields=required_fields,
+                from_form=self.model_form,
                 widget=Textarea(attrs=self.widget_attrs)
             )