utils.py 8.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235
  1. import re
  2. from django import forms
  3. from django.conf import settings
  4. from django.forms.models import fields_for_model
  5. from utilities.choices import unpack_grouped_choices
  6. from utilities.querysets import RestrictedQuerySet
  7. from .constants import *
  8. __all__ = (
  9. 'add_blank_choice',
  10. 'expand_alphanumeric_pattern',
  11. 'expand_ipaddress_pattern',
  12. 'form_from_model',
  13. 'get_selected_values',
  14. 'parse_alphanumeric_range',
  15. 'parse_numeric_range',
  16. 'restrict_form_fields',
  17. 'parse_csv',
  18. 'validate_csv',
  19. )
  20. def parse_numeric_range(string, base=10):
  21. """
  22. Expand a numeric range (continuous or not) into a decimal or
  23. hexadecimal list, as specified by the base parameter
  24. '0-3,5' => [0, 1, 2, 3, 5]
  25. '2,8-b,d,f' => [2, 8, 9, a, b, d, f]
  26. """
  27. values = list()
  28. for dash_range in string.split(','):
  29. try:
  30. begin, end = dash_range.split('-')
  31. except ValueError:
  32. begin, end = dash_range, dash_range
  33. try:
  34. begin, end = int(begin.strip(), base=base), int(end.strip(), base=base) + 1
  35. except ValueError:
  36. raise forms.ValidationError(f'Range "{dash_range}" is invalid.')
  37. values.extend(range(begin, end))
  38. return list(set(values))
  39. def parse_alphanumeric_range(string):
  40. """
  41. Expand an alphanumeric range (continuous or not) into a list.
  42. 'a-d,f' => [a, b, c, d, f]
  43. '0-3,a-d' => [0, 1, 2, 3, a, b, c, d]
  44. """
  45. values = []
  46. for dash_range in string.split(','):
  47. try:
  48. begin, end = dash_range.split('-')
  49. vals = begin + end
  50. # Break out of loop if there's an invalid pattern to return an error
  51. if (not (vals.isdigit() or vals.isalpha())) or (vals.isalpha() and not (vals.isupper() or vals.islower())):
  52. return []
  53. except ValueError:
  54. begin, end = dash_range, dash_range
  55. if begin.isdigit() and end.isdigit():
  56. for n in list(range(int(begin), int(end) + 1)):
  57. values.append(n)
  58. else:
  59. # Value-based
  60. if begin == end:
  61. values.append(begin)
  62. # Range-based
  63. else:
  64. # Not a valid range (more than a single character)
  65. if not len(begin) == len(end) == 1:
  66. raise forms.ValidationError(f'Range "{dash_range}" is invalid.')
  67. for n in list(range(ord(begin), ord(end) + 1)):
  68. values.append(chr(n))
  69. return values
  70. def expand_alphanumeric_pattern(string):
  71. """
  72. Expand an alphabetic pattern into a list of strings.
  73. """
  74. lead, pattern, remnant = re.split(ALPHANUMERIC_EXPANSION_PATTERN, string, maxsplit=1)
  75. parsed_range = parse_alphanumeric_range(pattern)
  76. for i in parsed_range:
  77. if re.search(ALPHANUMERIC_EXPANSION_PATTERN, remnant):
  78. for string in expand_alphanumeric_pattern(remnant):
  79. yield "{}{}{}".format(lead, i, string)
  80. else:
  81. yield "{}{}{}".format(lead, i, remnant)
  82. def expand_ipaddress_pattern(string, family):
  83. """
  84. Expand an IP address pattern into a list of strings. Examples:
  85. '192.0.2.[1,2,100-250]/24' => ['192.0.2.1/24', '192.0.2.2/24', '192.0.2.100/24' ... '192.0.2.250/24']
  86. '2001:db8:0:[0,fd-ff]::/64' => ['2001:db8:0:0::/64', '2001:db8:0:fd::/64', ... '2001:db8:0:ff::/64']
  87. """
  88. if family not in [4, 6]:
  89. raise Exception("Invalid IP address family: {}".format(family))
  90. if family == 4:
  91. regex = IP4_EXPANSION_PATTERN
  92. base = 10
  93. else:
  94. regex = IP6_EXPANSION_PATTERN
  95. base = 16
  96. lead, pattern, remnant = re.split(regex, string, maxsplit=1)
  97. parsed_range = parse_numeric_range(pattern, base)
  98. for i in parsed_range:
  99. if re.search(regex, remnant):
  100. for string in expand_ipaddress_pattern(remnant, family):
  101. yield ''.join([lead, format(i, 'x' if family == 6 else 'd'), string])
  102. else:
  103. yield ''.join([lead, format(i, 'x' if family == 6 else 'd'), remnant])
  104. def get_selected_values(form, field_name):
  105. """
  106. Return the list of selected human-friendly values for a form field
  107. """
  108. if not hasattr(form, 'cleaned_data'):
  109. form.is_valid()
  110. filter_data = form.cleaned_data.get(field_name)
  111. field = form.fields[field_name]
  112. # Non-selection field
  113. if not hasattr(field, 'choices'):
  114. return [str(filter_data)]
  115. # Model choice field
  116. if type(field.choices) is forms.models.ModelChoiceIterator:
  117. # If this is a single-choice field, wrap its value in a list
  118. if not hasattr(filter_data, '__iter__'):
  119. values = [filter_data]
  120. else:
  121. values = filter_data
  122. else:
  123. # Static selection field
  124. choices = unpack_grouped_choices(field.choices)
  125. if type(filter_data) not in (list, tuple):
  126. filter_data = [filter_data] # Ensure filter data is iterable
  127. values = [
  128. label for value, label in choices if str(value) in filter_data or None in filter_data
  129. ]
  130. # If the field has a `null_option` attribute set and it is selected,
  131. # add it to the field's grouped choices.
  132. if getattr(field, 'null_option', None) and None in filter_data:
  133. values.remove(None)
  134. values.insert(0, field.null_option)
  135. return values
  136. def add_blank_choice(choices):
  137. """
  138. Add a blank choice to the beginning of a choices list.
  139. """
  140. return ((None, '---------'),) + tuple(choices)
  141. def form_from_model(model, fields):
  142. """
  143. Return a Form class with the specified fields derived from a model. This is useful when we need a form to be used
  144. for creating objects, but want to avoid the model's validation (e.g. for bulk create/edit functions). All fields
  145. are marked as not required.
  146. """
  147. form_fields = fields_for_model(model, fields=fields)
  148. for field in form_fields.values():
  149. field.required = False
  150. return type('FormFromModel', (forms.Form,), form_fields)
  151. def restrict_form_fields(form, user, action='view'):
  152. """
  153. Restrict all form fields which reference a RestrictedQuerySet. This ensures that users see only permitted objects
  154. as available choices.
  155. """
  156. for field in form.fields.values():
  157. if hasattr(field, 'queryset') and issubclass(field.queryset.__class__, RestrictedQuerySet):
  158. field.queryset = field.queryset.restrict(user, action)
  159. def parse_csv(reader):
  160. """
  161. Parse a csv_reader object into a headers dictionary and a list of records dictionaries. Raise an error
  162. if the records are formatted incorrectly. Return headers and records as a tuple.
  163. """
  164. records = []
  165. headers = {}
  166. # Consume the first line of CSV data as column headers. Create a dictionary mapping each header to an optional
  167. # "to" field specifying how the related object is being referenced. For example, importing a Device might use a
  168. # `site.slug` header, to indicate the related site is being referenced by its slug.
  169. for header in next(reader):
  170. if '.' in header:
  171. field, to_field = header.split('.', 1)
  172. headers[field] = to_field
  173. else:
  174. headers[header] = None
  175. # Parse CSV rows into a list of dictionaries mapped from the column headers.
  176. for i, row in enumerate(reader, start=1):
  177. if len(row) != len(headers):
  178. raise forms.ValidationError(
  179. f"Row {i}: Expected {len(headers)} columns but found {len(row)}"
  180. )
  181. row = [col.strip() for col in row]
  182. record = dict(zip(headers.keys(), row))
  183. records.append(record)
  184. return headers, records
  185. def validate_csv(headers, fields, required_fields):
  186. """
  187. Validate that parsed csv data conforms to the object's available fields. Raise validation errors
  188. if parsed csv data contains invalid headers or does not contain required headers.
  189. """
  190. # Validate provided column headers
  191. for field, to_field in headers.items():
  192. if field not in fields:
  193. raise forms.ValidationError(f'Unexpected column header "{field}" found.')
  194. if to_field and not hasattr(fields[field], 'to_field_name'):
  195. raise forms.ValidationError(f'Column "{field}" is not a related object; cannot use dots')
  196. if to_field and not hasattr(fields[field].queryset.model, to_field):
  197. raise forms.ValidationError(f'Invalid related object attribute for column "{field}": {to_field}')
  198. # Validate required fields
  199. for f in required_fields:
  200. if f not in headers:
  201. raise forms.ValidationError(f'Required column header "{f}" not found.')