utils.py 9.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267
  1. import re
  2. from django import forms
  3. from django.forms.models import fields_for_model
  4. from utilities.choices import unpack_grouped_choices
  5. from utilities.querysets import RestrictedQuerySet
  6. from .constants import *
  7. __all__ = (
  8. 'add_blank_choice',
  9. 'expand_alphanumeric_pattern',
  10. 'expand_ipaddress_pattern',
  11. 'form_from_model',
  12. 'get_field_value',
  13. 'get_selected_values',
  14. 'parse_alphanumeric_range',
  15. 'parse_numeric_range',
  16. 'restrict_form_fields',
  17. 'parse_csv',
  18. 'validate_csv',
  19. )
  20. def parse_numeric_range(string, base=10):
  21. """
  22. Expand a numeric range (continuous or not) into a decimal or
  23. hexadecimal list, as specified by the base parameter
  24. '0-3,5' => [0, 1, 2, 3, 5]
  25. '2,8-b,d,f' => [2, 8, 9, a, b, d, f]
  26. """
  27. values = list()
  28. for dash_range in string.split(','):
  29. try:
  30. begin, end = dash_range.split('-')
  31. except ValueError:
  32. begin, end = dash_range, dash_range
  33. try:
  34. begin, end = int(begin.strip(), base=base), int(end.strip(), base=base) + 1
  35. except ValueError:
  36. raise forms.ValidationError(f'Range "{dash_range}" is invalid.')
  37. values.extend(range(begin, end))
  38. return list(set(values))
  39. def parse_alphanumeric_range(string):
  40. """
  41. Expand an alphanumeric range (continuous or not) into a list.
  42. 'a-d,f' => [a, b, c, d, f]
  43. '0-3,a-d' => [0, 1, 2, 3, a, b, c, d]
  44. """
  45. values = []
  46. for dash_range in string.split(','):
  47. try:
  48. begin, end = dash_range.split('-')
  49. vals = begin + end
  50. # Break out of loop if there's an invalid pattern to return an error
  51. if (not (vals.isdigit() or vals.isalpha())) or (vals.isalpha() and not (vals.isupper() or vals.islower())):
  52. return []
  53. except ValueError:
  54. begin, end = dash_range, dash_range
  55. if begin.isdigit() and end.isdigit():
  56. if int(begin) >= int(end):
  57. raise forms.ValidationError(f'Range "{dash_range}" is invalid.')
  58. for n in list(range(int(begin), int(end) + 1)):
  59. values.append(n)
  60. else:
  61. # Value-based
  62. if begin == end:
  63. values.append(begin)
  64. # Range-based
  65. else:
  66. # Not a valid range (more than a single character)
  67. if not len(begin) == len(end) == 1:
  68. raise forms.ValidationError(f'Range "{dash_range}" is invalid.')
  69. if ord(begin) >= ord(end):
  70. raise forms.ValidationError(f'Range "{dash_range}" is invalid.')
  71. for n in list(range(ord(begin), ord(end) + 1)):
  72. values.append(chr(n))
  73. return values
  74. def expand_alphanumeric_pattern(string):
  75. """
  76. Expand an alphabetic pattern into a list of strings.
  77. """
  78. lead, pattern, remnant = re.split(ALPHANUMERIC_EXPANSION_PATTERN, string, maxsplit=1)
  79. parsed_range = parse_alphanumeric_range(pattern)
  80. for i in parsed_range:
  81. if re.search(ALPHANUMERIC_EXPANSION_PATTERN, remnant):
  82. for string in expand_alphanumeric_pattern(remnant):
  83. yield "{}{}{}".format(lead, i, string)
  84. else:
  85. yield "{}{}{}".format(lead, i, remnant)
  86. def expand_ipaddress_pattern(string, family):
  87. """
  88. Expand an IP address pattern into a list of strings. Examples:
  89. '192.0.2.[1,2,100-250]/24' => ['192.0.2.1/24', '192.0.2.2/24', '192.0.2.100/24' ... '192.0.2.250/24']
  90. '2001:db8:0:[0,fd-ff]::/64' => ['2001:db8:0:0::/64', '2001:db8:0:fd::/64', ... '2001:db8:0:ff::/64']
  91. """
  92. if family not in [4, 6]:
  93. raise Exception("Invalid IP address family: {}".format(family))
  94. if family == 4:
  95. regex = IP4_EXPANSION_PATTERN
  96. base = 10
  97. else:
  98. regex = IP6_EXPANSION_PATTERN
  99. base = 16
  100. lead, pattern, remnant = re.split(regex, string, maxsplit=1)
  101. parsed_range = parse_numeric_range(pattern, base)
  102. for i in parsed_range:
  103. if re.search(regex, remnant):
  104. for string in expand_ipaddress_pattern(remnant, family):
  105. yield ''.join([lead, format(i, 'x' if family == 6 else 'd'), string])
  106. else:
  107. yield ''.join([lead, format(i, 'x' if family == 6 else 'd'), remnant])
  108. def get_field_value(form, field_name):
  109. """
  110. Return the current bound or initial value associated with a form field, prior to calling
  111. clean() for the form.
  112. """
  113. field = form.fields[field_name]
  114. if form.is_bound:
  115. if data := form.data.get(field_name):
  116. if field.valid_value(data):
  117. return data
  118. return form.get_initial_for_field(field, field_name)
  119. def get_selected_values(form, field_name):
  120. """
  121. Return the list of selected human-friendly values for a form field
  122. """
  123. if not hasattr(form, 'cleaned_data'):
  124. form.is_valid()
  125. filter_data = form.cleaned_data.get(field_name)
  126. field = form.fields[field_name]
  127. # Non-selection field
  128. if not hasattr(field, 'choices'):
  129. return [str(filter_data)]
  130. # Model choice field
  131. if type(field.choices) is forms.models.ModelChoiceIterator:
  132. # If this is a single-choice field, wrap its value in a list
  133. if not hasattr(filter_data, '__iter__'):
  134. values = [filter_data]
  135. else:
  136. values = filter_data
  137. else:
  138. # Static selection field
  139. choices = unpack_grouped_choices(field.choices)
  140. if type(filter_data) not in (list, tuple):
  141. filter_data = [filter_data] # Ensure filter data is iterable
  142. values = [
  143. label for value, label in choices if str(value) in filter_data or None in filter_data
  144. ]
  145. # If the field has a `null_option` attribute set and it is selected,
  146. # add it to the field's grouped choices.
  147. if getattr(field, 'null_option', None) and None in filter_data:
  148. values.remove(None)
  149. values.insert(0, field.null_option)
  150. return values
  151. def add_blank_choice(choices):
  152. """
  153. Add a blank choice to the beginning of a choices list.
  154. """
  155. return ((None, '---------'),) + tuple(choices)
  156. def form_from_model(model, fields):
  157. """
  158. Return a Form class with the specified fields derived from a model. This is useful when we need a form to be used
  159. for creating objects, but want to avoid the model's validation (e.g. for bulk create/edit functions). All fields
  160. are marked as not required.
  161. """
  162. form_fields = fields_for_model(model, fields=fields)
  163. for field in form_fields.values():
  164. field.required = False
  165. return type('FormFromModel', (forms.Form,), form_fields)
  166. def restrict_form_fields(form, user, action='view'):
  167. """
  168. Restrict all form fields which reference a RestrictedQuerySet. This ensures that users see only permitted objects
  169. as available choices.
  170. """
  171. for field in form.fields.values():
  172. if hasattr(field, 'queryset') and issubclass(field.queryset.__class__, RestrictedQuerySet):
  173. field.queryset = field.queryset.restrict(user, action)
  174. def parse_csv(reader):
  175. """
  176. Parse a csv_reader object into a headers dictionary and a list of records dictionaries. Raise an error
  177. if the records are formatted incorrectly. Return headers and records as a tuple.
  178. """
  179. records = []
  180. headers = {}
  181. # Consume the first line of CSV data as column headers. Create a dictionary mapping each header to an optional
  182. # "to" field specifying how the related object is being referenced. For example, importing a Device might use a
  183. # `site.slug` header, to indicate the related site is being referenced by its slug.
  184. for header in next(reader):
  185. header = header.strip()
  186. if '.' in header:
  187. field, to_field = header.split('.', 1)
  188. if field in headers:
  189. raise forms.ValidationError(f'Duplicate or conflicting column header for "{field}"')
  190. headers[field] = to_field
  191. else:
  192. if header in headers:
  193. raise forms.ValidationError(f'Duplicate or conflicting column header for "{header}"')
  194. headers[header] = None
  195. # Parse CSV rows into a list of dictionaries mapped from the column headers.
  196. for i, row in enumerate(reader, start=1):
  197. if len(row) != len(headers):
  198. raise forms.ValidationError(
  199. f"Row {i}: Expected {len(headers)} columns but found {len(row)}"
  200. )
  201. row = [col.strip() for col in row]
  202. record = dict(zip(headers.keys(), row))
  203. records.append(record)
  204. return headers, records
  205. def validate_csv(headers, fields, required_fields):
  206. """
  207. Validate that parsed csv data conforms to the object's available fields. Raise validation errors
  208. if parsed csv data contains invalid headers or does not contain required headers.
  209. """
  210. # Validate provided column headers
  211. is_update = False
  212. for field, to_field in headers.items():
  213. if field == "id":
  214. is_update = True
  215. continue
  216. if field not in fields:
  217. raise forms.ValidationError(f'Unexpected column header "{field}" found.')
  218. if to_field and not hasattr(fields[field], 'to_field_name'):
  219. raise forms.ValidationError(f'Column "{field}" is not a related object; cannot use dots')
  220. if to_field and not hasattr(fields[field].queryset.model, to_field):
  221. raise forms.ValidationError(f'Invalid related object attribute for column "{field}": {to_field}')
  222. # Validate required fields (if not an update)
  223. if not is_update:
  224. for f in required_fields:
  225. if f not in headers:
  226. raise forms.ValidationError(f'Required column header "{f}" not found.')