utils.py 7.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193
  1. import re
  2. from django import forms
  3. from django.forms.models import fields_for_model
  4. from utilities.querysets import RestrictedQuerySet
  5. from .constants import *
  6. __all__ = (
  7. 'add_blank_choice',
  8. 'expand_alphanumeric_pattern',
  9. 'expand_ipaddress_pattern',
  10. 'form_from_model',
  11. 'parse_alphanumeric_range',
  12. 'parse_numeric_range',
  13. 'restrict_form_fields',
  14. 'parse_csv',
  15. 'validate_csv',
  16. )
  17. def parse_numeric_range(string, base=10):
  18. """
  19. Expand a numeric range (continuous or not) into a decimal or
  20. hexadecimal list, as specified by the base parameter
  21. '0-3,5' => [0, 1, 2, 3, 5]
  22. '2,8-b,d,f' => [2, 8, 9, a, b, d, f]
  23. """
  24. values = list()
  25. for dash_range in string.split(','):
  26. try:
  27. begin, end = dash_range.split('-')
  28. except ValueError:
  29. begin, end = dash_range, dash_range
  30. try:
  31. begin, end = int(begin.strip(), base=base), int(end.strip(), base=base) + 1
  32. except ValueError:
  33. raise forms.ValidationError(f'Range "{dash_range}" is invalid.')
  34. values.extend(range(begin, end))
  35. return list(set(values))
  36. def parse_alphanumeric_range(string):
  37. """
  38. Expand an alphanumeric range (continuous or not) into a list.
  39. 'a-d,f' => [a, b, c, d, f]
  40. '0-3,a-d' => [0, 1, 2, 3, a, b, c, d]
  41. """
  42. values = []
  43. for dash_range in string.split(','):
  44. try:
  45. begin, end = dash_range.split('-')
  46. vals = begin + end
  47. # Break out of loop if there's an invalid pattern to return an error
  48. if (not (vals.isdigit() or vals.isalpha())) or (vals.isalpha() and not (vals.isupper() or vals.islower())):
  49. return []
  50. except ValueError:
  51. begin, end = dash_range, dash_range
  52. if begin.isdigit() and end.isdigit():
  53. for n in list(range(int(begin), int(end) + 1)):
  54. values.append(n)
  55. else:
  56. # Value-based
  57. if begin == end:
  58. values.append(begin)
  59. # Range-based
  60. else:
  61. # Not a valid range (more than a single character)
  62. if not len(begin) == len(end) == 1:
  63. raise forms.ValidationError(f'Range "{dash_range}" is invalid.')
  64. for n in list(range(ord(begin), ord(end) + 1)):
  65. values.append(chr(n))
  66. return values
  67. def expand_alphanumeric_pattern(string):
  68. """
  69. Expand an alphabetic pattern into a list of strings.
  70. """
  71. lead, pattern, remnant = re.split(ALPHANUMERIC_EXPANSION_PATTERN, string, maxsplit=1)
  72. parsed_range = parse_alphanumeric_range(pattern)
  73. for i in parsed_range:
  74. if re.search(ALPHANUMERIC_EXPANSION_PATTERN, remnant):
  75. for string in expand_alphanumeric_pattern(remnant):
  76. yield "{}{}{}".format(lead, i, string)
  77. else:
  78. yield "{}{}{}".format(lead, i, remnant)
  79. def expand_ipaddress_pattern(string, family):
  80. """
  81. Expand an IP address pattern into a list of strings. Examples:
  82. '192.0.2.[1,2,100-250]/24' => ['192.0.2.1/24', '192.0.2.2/24', '192.0.2.100/24' ... '192.0.2.250/24']
  83. '2001:db8:0:[0,fd-ff]::/64' => ['2001:db8:0:0::/64', '2001:db8:0:fd::/64', ... '2001:db8:0:ff::/64']
  84. """
  85. if family not in [4, 6]:
  86. raise Exception("Invalid IP address family: {}".format(family))
  87. if family == 4:
  88. regex = IP4_EXPANSION_PATTERN
  89. base = 10
  90. else:
  91. regex = IP6_EXPANSION_PATTERN
  92. base = 16
  93. lead, pattern, remnant = re.split(regex, string, maxsplit=1)
  94. parsed_range = parse_numeric_range(pattern, base)
  95. for i in parsed_range:
  96. if re.search(regex, remnant):
  97. for string in expand_ipaddress_pattern(remnant, family):
  98. yield ''.join([lead, format(i, 'x' if family == 6 else 'd'), string])
  99. else:
  100. yield ''.join([lead, format(i, 'x' if family == 6 else 'd'), remnant])
  101. def add_blank_choice(choices):
  102. """
  103. Add a blank choice to the beginning of a choices list.
  104. """
  105. return ((None, '---------'),) + tuple(choices)
  106. def form_from_model(model, fields):
  107. """
  108. Return a Form class with the specified fields derived from a model. This is useful when we need a form to be used
  109. for creating objects, but want to avoid the model's validation (e.g. for bulk create/edit functions). All fields
  110. are marked as not required.
  111. """
  112. form_fields = fields_for_model(model, fields=fields)
  113. for field in form_fields.values():
  114. field.required = False
  115. return type('FormFromModel', (forms.Form,), form_fields)
  116. def restrict_form_fields(form, user, action='view'):
  117. """
  118. Restrict all form fields which reference a RestrictedQuerySet. This ensures that users see only permitted objects
  119. as available choices.
  120. """
  121. for field in form.fields.values():
  122. if hasattr(field, 'queryset') and issubclass(field.queryset.__class__, RestrictedQuerySet):
  123. field.queryset = field.queryset.restrict(user, action)
  124. def parse_csv(reader):
  125. """
  126. Parse a csv_reader object into a headers dictionary and a list of records dictionaries. Raise an error
  127. if the records are formatted incorrectly. Return headers and records as a tuple.
  128. """
  129. records = []
  130. headers = {}
  131. # Consume the first line of CSV data as column headers. Create a dictionary mapping each header to an optional
  132. # "to" field specifying how the related object is being referenced. For example, importing a Device might use a
  133. # `site.slug` header, to indicate the related site is being referenced by its slug.
  134. for header in next(reader):
  135. if '.' in header:
  136. field, to_field = header.split('.', 1)
  137. headers[field] = to_field
  138. else:
  139. headers[header] = None
  140. # Parse CSV rows into a list of dictionaries mapped from the column headers.
  141. for i, row in enumerate(reader, start=1):
  142. if len(row) != len(headers):
  143. raise forms.ValidationError(
  144. f"Row {i}: Expected {len(headers)} columns but found {len(row)}"
  145. )
  146. row = [col.strip() for col in row]
  147. record = dict(zip(headers.keys(), row))
  148. records.append(record)
  149. return headers, records
  150. def validate_csv(headers, fields, required_fields):
  151. """
  152. Validate that parsed csv data conforms to the object's available fields. Raise validation errors
  153. if parsed csv data contains invalid headers or does not contain required headers.
  154. """
  155. # Validate provided column headers
  156. for field, to_field in headers.items():
  157. if field not in fields:
  158. raise forms.ValidationError(f'Unexpected column header "{field}" found.')
  159. if to_field and not hasattr(fields[field], 'to_field_name'):
  160. raise forms.ValidationError(f'Column "{field}" is not a related object; cannot use dots')
  161. if to_field and not hasattr(fields[field].queryset.model, to_field):
  162. raise forms.ValidationError(f'Invalid related object attribute for column "{field}": {to_field}')
  163. # Validate required fields
  164. for f in required_fields:
  165. if f not in headers:
  166. raise forms.ValidationError(f'Required column header "{f}" not found.')