utils.py 8.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228
  1. import re
  2. from django import forms
  3. from django.conf import settings
  4. from django.forms.models import fields_for_model
  5. from utilities.choices import unpack_grouped_choices
  6. from utilities.querysets import RestrictedQuerySet
  7. from .constants import *
  8. __all__ = (
  9. 'add_blank_choice',
  10. 'expand_alphanumeric_pattern',
  11. 'expand_ipaddress_pattern',
  12. 'form_from_model',
  13. 'get_selected_values',
  14. 'parse_alphanumeric_range',
  15. 'parse_numeric_range',
  16. 'restrict_form_fields',
  17. 'parse_csv',
  18. 'validate_csv',
  19. )
  20. def parse_numeric_range(string, base=10):
  21. """
  22. Expand a numeric range (continuous or not) into a decimal or
  23. hexadecimal list, as specified by the base parameter
  24. '0-3,5' => [0, 1, 2, 3, 5]
  25. '2,8-b,d,f' => [2, 8, 9, a, b, d, f]
  26. """
  27. values = list()
  28. for dash_range in string.split(','):
  29. try:
  30. begin, end = dash_range.split('-')
  31. except ValueError:
  32. begin, end = dash_range, dash_range
  33. try:
  34. begin, end = int(begin.strip(), base=base), int(end.strip(), base=base) + 1
  35. except ValueError:
  36. raise forms.ValidationError(f'Range "{dash_range}" is invalid.')
  37. values.extend(range(begin, end))
  38. return list(set(values))
  39. def parse_alphanumeric_range(string):
  40. """
  41. Expand an alphanumeric range (continuous or not) into a list.
  42. 'a-d,f' => [a, b, c, d, f]
  43. '0-3,a-d' => [0, 1, 2, 3, a, b, c, d]
  44. """
  45. values = []
  46. for dash_range in string.split(','):
  47. try:
  48. begin, end = dash_range.split('-')
  49. vals = begin + end
  50. # Break out of loop if there's an invalid pattern to return an error
  51. if (not (vals.isdigit() or vals.isalpha())) or (vals.isalpha() and not (vals.isupper() or vals.islower())):
  52. return []
  53. except ValueError:
  54. begin, end = dash_range, dash_range
  55. if begin.isdigit() and end.isdigit():
  56. for n in list(range(int(begin), int(end) + 1)):
  57. values.append(n)
  58. else:
  59. # Value-based
  60. if begin == end:
  61. values.append(begin)
  62. # Range-based
  63. else:
  64. # Not a valid range (more than a single character)
  65. if not len(begin) == len(end) == 1:
  66. raise forms.ValidationError(f'Range "{dash_range}" is invalid.')
  67. for n in list(range(ord(begin), ord(end) + 1)):
  68. values.append(chr(n))
  69. return values
  70. def expand_alphanumeric_pattern(string):
  71. """
  72. Expand an alphabetic pattern into a list of strings.
  73. """
  74. lead, pattern, remnant = re.split(ALPHANUMERIC_EXPANSION_PATTERN, string, maxsplit=1)
  75. parsed_range = parse_alphanumeric_range(pattern)
  76. for i in parsed_range:
  77. if re.search(ALPHANUMERIC_EXPANSION_PATTERN, remnant):
  78. for string in expand_alphanumeric_pattern(remnant):
  79. yield "{}{}{}".format(lead, i, string)
  80. else:
  81. yield "{}{}{}".format(lead, i, remnant)
  82. def expand_ipaddress_pattern(string, family):
  83. """
  84. Expand an IP address pattern into a list of strings. Examples:
  85. '192.0.2.[1,2,100-250]/24' => ['192.0.2.1/24', '192.0.2.2/24', '192.0.2.100/24' ... '192.0.2.250/24']
  86. '2001:db8:0:[0,fd-ff]::/64' => ['2001:db8:0:0::/64', '2001:db8:0:fd::/64', ... '2001:db8:0:ff::/64']
  87. """
  88. if family not in [4, 6]:
  89. raise Exception("Invalid IP address family: {}".format(family))
  90. if family == 4:
  91. regex = IP4_EXPANSION_PATTERN
  92. base = 10
  93. else:
  94. regex = IP6_EXPANSION_PATTERN
  95. base = 16
  96. lead, pattern, remnant = re.split(regex, string, maxsplit=1)
  97. parsed_range = parse_numeric_range(pattern, base)
  98. for i in parsed_range:
  99. if re.search(regex, remnant):
  100. for string in expand_ipaddress_pattern(remnant, family):
  101. yield ''.join([lead, format(i, 'x' if family == 6 else 'd'), string])
  102. else:
  103. yield ''.join([lead, format(i, 'x' if family == 6 else 'd'), remnant])
  104. def get_selected_values(form, field_name):
  105. """
  106. Return the list of selected human-friendly values for a form field
  107. """
  108. if not hasattr(form, 'cleaned_data'):
  109. form.is_valid()
  110. filter_data = form.cleaned_data.get(field_name)
  111. field = form.fields[field_name]
  112. # Selection field
  113. if hasattr(field, 'choices'):
  114. try:
  115. choices = unpack_grouped_choices(field.choices)
  116. if hasattr(field, 'null_option'):
  117. # If the field has a `null_option` attribute set and it is selected,
  118. # add it to the field's grouped choices.
  119. if field.null_option is not None and None in filter_data:
  120. choices.append((settings.FILTERS_NULL_CHOICE_VALUE, field.null_option))
  121. return [
  122. label for value, label in choices if str(value) in filter_data or None in filter_data
  123. ]
  124. except TypeError:
  125. # Field uses dynamic choices. Show all that have been populated.
  126. return [
  127. subwidget.choice_label for subwidget in form[field_name].subwidgets
  128. ]
  129. # Non-selection field
  130. return [str(filter_data)]
  131. def add_blank_choice(choices):
  132. """
  133. Add a blank choice to the beginning of a choices list.
  134. """
  135. return ((None, '---------'),) + tuple(choices)
  136. def form_from_model(model, fields):
  137. """
  138. Return a Form class with the specified fields derived from a model. This is useful when we need a form to be used
  139. for creating objects, but want to avoid the model's validation (e.g. for bulk create/edit functions). All fields
  140. are marked as not required.
  141. """
  142. form_fields = fields_for_model(model, fields=fields)
  143. for field in form_fields.values():
  144. field.required = False
  145. return type('FormFromModel', (forms.Form,), form_fields)
  146. def restrict_form_fields(form, user, action='view'):
  147. """
  148. Restrict all form fields which reference a RestrictedQuerySet. This ensures that users see only permitted objects
  149. as available choices.
  150. """
  151. for field in form.fields.values():
  152. if hasattr(field, 'queryset') and issubclass(field.queryset.__class__, RestrictedQuerySet):
  153. field.queryset = field.queryset.restrict(user, action)
  154. def parse_csv(reader):
  155. """
  156. Parse a csv_reader object into a headers dictionary and a list of records dictionaries. Raise an error
  157. if the records are formatted incorrectly. Return headers and records as a tuple.
  158. """
  159. records = []
  160. headers = {}
  161. # Consume the first line of CSV data as column headers. Create a dictionary mapping each header to an optional
  162. # "to" field specifying how the related object is being referenced. For example, importing a Device might use a
  163. # `site.slug` header, to indicate the related site is being referenced by its slug.
  164. for header in next(reader):
  165. if '.' in header:
  166. field, to_field = header.split('.', 1)
  167. headers[field] = to_field
  168. else:
  169. headers[header] = None
  170. # Parse CSV rows into a list of dictionaries mapped from the column headers.
  171. for i, row in enumerate(reader, start=1):
  172. if len(row) != len(headers):
  173. raise forms.ValidationError(
  174. f"Row {i}: Expected {len(headers)} columns but found {len(row)}"
  175. )
  176. row = [col.strip() for col in row]
  177. record = dict(zip(headers.keys(), row))
  178. records.append(record)
  179. return headers, records
  180. def validate_csv(headers, fields, required_fields):
  181. """
  182. Validate that parsed csv data conforms to the object's available fields. Raise validation errors
  183. if parsed csv data contains invalid headers or does not contain required headers.
  184. """
  185. # Validate provided column headers
  186. for field, to_field in headers.items():
  187. if field not in fields:
  188. raise forms.ValidationError(f'Unexpected column header "{field}" found.')
  189. if to_field and not hasattr(fields[field], 'to_field_name'):
  190. raise forms.ValidationError(f'Column "{field}" is not a related object; cannot use dots')
  191. if to_field and not hasattr(fields[field].queryset.model, to_field):
  192. raise forms.ValidationError(f'Invalid related object attribute for column "{field}": {to_field}')
  193. # Validate required fields
  194. for f in required_fields:
  195. if f not in headers:
  196. raise forms.ValidationError(f'Required column header "{f}" not found.')