Explorar o código

moved duplicated code in CSV Fields into functions in forms/utils.py

Alyssa Bigley %!s(int64=4) %!d(string=hai) anos
pai
achega
0a661596b3
Modificáronse 2 ficheiros con 59 adicións e 74 borrados
  1. 6 74
      netbox/utilities/forms/fields.py
  2. 53 0
      netbox/utilities/forms/utils.py

+ 6 - 74
netbox/utilities/forms/fields.py

@@ -17,7 +17,7 @@ from utilities.utils import content_type_name
 from utilities.validators import EnhancedURLValidator
 from . import widgets
 from .constants import *
-from .utils import expand_alphanumeric_pattern, expand_ipaddress_pattern
+from .utils import expand_alphanumeric_pattern, expand_ipaddress_pattern, parse_csv, validate_csv
 
 __all__ = (
     'CommentField',
@@ -175,49 +175,13 @@ class CSVDataField(forms.CharField):
                              'in double quotes.'
 
     def to_python(self, value):
-
-        records = []
         reader = csv.reader(StringIO(value.strip()))
 
-        # Consume the first line of CSV data as column headers. Create a dictionary mapping each header to an optional
-        # "to" field specifying how the related object is being referenced. For example, importing a Device might use a
-        # `site.slug` header, to indicate the related site is being referenced by its slug.
-        headers = {}
-        for header in next(reader):
-            if '.' in header:
-                field, to_field = header.split('.', 1)
-                headers[field] = to_field
-            else:
-                headers[header] = None
-
-        # Parse CSV rows into a list of dictionaries mapped from the column headers.
-        for i, row in enumerate(reader, start=1):
-            if len(row) != len(headers):
-                raise forms.ValidationError(
-                    f"Row {i}: Expected {len(headers)} columns but found {len(row)}"
-                )
-            row = [col.strip() for col in row]
-            record = dict(zip(headers.keys(), row))
-            records.append(record)
-
-        return headers, records
+        return parse_csv(reader)
 
     def validate(self, value):
         headers, records = value
-
-        # Validate provided column headers
-        for field, to_field in headers.items():
-            if field not in self.fields:
-                raise forms.ValidationError(f'Unexpected column header "{field}" found.')
-            if to_field and not hasattr(self.fields[field], 'to_field_name'):
-                raise forms.ValidationError(f'Column "{field}" is not a related object; cannot use dots')
-            if to_field and not hasattr(self.fields[field].queryset.model, to_field):
-                raise forms.ValidationError(f'Invalid related object attribute for column "{field}": {to_field}')
-
-        # Validate required fields
-        for f in self.required_fields:
-            if f not in headers:
-                raise forms.ValidationError(f'Required column header "{f}" not found.')
+        validate_csv(headers, self.fields, self.required_fields)
 
         return value
 
@@ -244,34 +208,14 @@ class CSVFileField(forms.FileField):
         super().__init__(*args, **kwargs)
 
     def to_python(self, file):
-
-        records = []
         if file:
             csv_str = file.read().decode('utf-8')
             reader = csv.reader(csv_str.splitlines())
 
-        # Consume the first line of CSV data as column headers. Create a dictionary mapping each header to an optional
-        # "to" field specifying how the related object is being referenced. For example, importing a Device might use a
-        # `site.slug` header, to indicate the related site is being referenced by its slug.
-
         headers = {}
+        records = []
         if file:
-            for header in next(reader):
-                if '.' in header:
-                    field, to_field = header.split('.', 1)
-                    headers[field] = to_field
-                else:
-                    headers[header] = None
-
-            # Parse CSV rows into a list of dictionaries mapped from the column headers.
-            for i, row in enumerate(reader, start=1):
-                if len(row) != len(headers):
-                    raise forms.ValidationError(
-                        f"Row {i}: Expected {len(headers)} columns but found {len(row)}"
-                    )
-                row = [col.strip() for col in row]
-                record = dict(zip(headers.keys(), row))
-                records.append(record)
+            headers, records = parse_csv(reader)
 
         return headers, records
 
@@ -280,19 +224,7 @@ class CSVFileField(forms.FileField):
         if not headers and not records:
             return value
 
-        # Validate provided column headers
-        for field, to_field in headers.items():
-            if field not in self.fields:
-                raise forms.ValidationError(f'Unexpected column header "{field}" found.')
-            if to_field and not hasattr(self.fields[field], 'to_field_name'):
-                raise forms.ValidationError(f'Column "{field}" is not a related object; cannot use dots')
-            if to_field and not hasattr(self.fields[field].queryset.model, to_field):
-                raise forms.ValidationError(f'Invalid related object attribute for column "{field}": {to_field}')
-
-        # Validate required fields
-        for f in self.required_fields:
-            if f not in headers:
-                raise forms.ValidationError(f'Required column header "{f}" not found.')
+        validate_csv(headers, self.fields, self.required_fields)
 
         return value
 

+ 53 - 0
netbox/utilities/forms/utils.py

@@ -14,6 +14,8 @@ __all__ = (
     'parse_alphanumeric_range',
     'parse_numeric_range',
     'restrict_form_fields',
+    'parse_csv',
+    'validate_csv',
 )
 
 
@@ -134,3 +136,54 @@ def restrict_form_fields(form, user, action='view'):
     for field in form.fields.values():
         if hasattr(field, 'queryset') and issubclass(field.queryset.__class__, RestrictedQuerySet):
             field.queryset = field.queryset.restrict(user, action)
+
+
+def parse_csv(reader):
+    """
+    Parse a csv_reader object into a headers dictionary and a list of records dictionaries. Raise an error
+    if the records are formatted incorrectly. Return headers and records as a tuple.
+    """
+    records = []
+    headers = {}
+
+    # Consume the first line of CSV data as column headers. Create a dictionary mapping each header to an optional
+    # "to" field specifying how the related object is being referenced. For example, importing a Device might use a
+    # `site.slug` header, to indicate the related site is being referenced by its slug.
+
+    for header in next(reader):
+        if '.' in header:
+            field, to_field = header.split('.', 1)
+            headers[field] = to_field
+        else:
+            headers[header] = None
+
+    # Parse CSV rows into a list of dictionaries mapped from the column headers.
+    for i, row in enumerate(reader, start=1):
+        if len(row) != len(headers):
+            raise forms.ValidationError(
+                f"Row {i}: Expected {len(headers)} columns but found {len(row)}"
+            )
+        row = [col.strip() for col in row]
+        record = dict(zip(headers.keys(), row))
+        records.append(record)
+    return headers, records
+
+
+def validate_csv(headers, fields, required_fields):
+    """
+    Validate that parsed csv data conforms to the object's available fields. Raise validation errors
+    if parsed csv data contains invalid headers or does not contain required headers.
+    """
+    # Validate provided column headers
+    for field, to_field in headers.items():
+        if field not in fields:
+            raise forms.ValidationError(f'Unexpected column header "{field}" found.')
+        if to_field and not hasattr(fields[field], 'to_field_name'):
+            raise forms.ValidationError(f'Column "{field}" is not a related object; cannot use dots')
+        if to_field and not hasattr(fields[field].queryset.model, to_field):
+            raise forms.ValidationError(f'Invalid related object attribute for column "{field}": {to_field}')
+
+    # Validate required fields
+    for f in required_fields:
+        if f not in headers:
+            raise forms.ValidationError(f'Required column header "{f}" not found.')