utils.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284
  1. import re
  2. from django import forms
  3. from django.forms.models import fields_for_model
  4. from django.utils.translation import gettext as _
  5. from utilities.choices import unpack_grouped_choices
  6. from utilities.querysets import RestrictedQuerySet
  7. from .constants import *
  8. __all__ = (
  9. 'add_blank_choice',
  10. 'expand_alphanumeric_pattern',
  11. 'expand_ipaddress_pattern',
  12. 'form_from_model',
  13. 'get_field_value',
  14. 'get_selected_values',
  15. 'parse_alphanumeric_range',
  16. 'parse_numeric_range',
  17. 'restrict_form_fields',
  18. 'parse_csv',
  19. 'validate_csv',
  20. )
  21. def parse_numeric_range(string, base=10):
  22. """
  23. Expand a numeric range (continuous or not) into a decimal or
  24. hexadecimal list, as specified by the base parameter
  25. '0-3,5' => [0, 1, 2, 3, 5]
  26. '2,8-b,d,f' => [2, 8, 9, a, b, d, f]
  27. """
  28. values = list()
  29. for dash_range in string.split(','):
  30. try:
  31. begin, end = dash_range.split('-')
  32. except ValueError:
  33. begin, end = dash_range, dash_range
  34. try:
  35. begin, end = int(begin.strip(), base=base), int(end.strip(), base=base) + 1
  36. except ValueError:
  37. raise forms.ValidationError(_('Range "{value}" is invalid.').format(value=dash_range))
  38. values.extend(range(begin, end))
  39. return sorted(set(values))
  40. def parse_alphanumeric_range(string):
  41. """
  42. Expand an alphanumeric range (continuous or not) into a list.
  43. 'a-d,f' => [a, b, c, d, f]
  44. '0-3,a-d' => [0, 1, 2, 3, a, b, c, d]
  45. """
  46. values = []
  47. for value in string.split(','):
  48. if '-' not in value:
  49. # Item is not a range
  50. values.append(value)
  51. continue
  52. # Find the range's beginning & end values
  53. try:
  54. begin, end = value.split('-')
  55. vals = begin + end
  56. # Break out of loop if there's an invalid pattern to return an error
  57. if (not (vals.isdigit() or vals.isalpha())) or (vals.isalpha() and not (vals.isupper() or vals.islower())):
  58. return []
  59. except ValueError:
  60. raise forms.ValidationError(_('Range "{value}" is invalid.').format(value=value))
  61. # Numeric range
  62. if begin.isdigit() and end.isdigit():
  63. if int(begin) >= int(end):
  64. raise forms.ValidationError(
  65. _('Invalid range: Ending value ({end}) must be greater than beginning value ({begin}).').format(
  66. begin=begin, end=end
  67. )
  68. )
  69. for n in list(range(int(begin), int(end) + 1)):
  70. values.append(n)
  71. # Alphanumeric range
  72. else:
  73. # Not a valid range (more than a single character)
  74. if not len(begin) == len(end) == 1:
  75. raise forms.ValidationError(_('Range "{value}" is invalid.').format(value=value))
  76. if ord(begin) >= ord(end):
  77. raise forms.ValidationError(_('Range "{value}" is invalid.').format(value=value))
  78. for n in list(range(ord(begin), ord(end) + 1)):
  79. values.append(chr(n))
  80. return values
  81. def expand_alphanumeric_pattern(string):
  82. """
  83. Expand an alphabetic pattern into a list of strings.
  84. """
  85. lead, pattern, remnant = re.split(ALPHANUMERIC_EXPANSION_PATTERN, string, maxsplit=1)
  86. parsed_range = parse_alphanumeric_range(pattern)
  87. for i in parsed_range:
  88. if re.search(ALPHANUMERIC_EXPANSION_PATTERN, remnant):
  89. for string in expand_alphanumeric_pattern(remnant):
  90. yield "{}{}{}".format(lead, i, string)
  91. else:
  92. yield "{}{}{}".format(lead, i, remnant)
  93. def expand_ipaddress_pattern(string, family):
  94. """
  95. Expand an IP address pattern into a list of strings. Examples:
  96. '192.0.2.[1,2,100-250]/24' => ['192.0.2.1/24', '192.0.2.2/24', '192.0.2.100/24' ... '192.0.2.250/24']
  97. '2001:db8:0:[0,fd-ff]::/64' => ['2001:db8:0:0::/64', '2001:db8:0:fd::/64', ... '2001:db8:0:ff::/64']
  98. """
  99. if family not in [4, 6]:
  100. raise Exception("Invalid IP address family: {}".format(family))
  101. if family == 4:
  102. regex = IP4_EXPANSION_PATTERN
  103. base = 10
  104. else:
  105. regex = IP6_EXPANSION_PATTERN
  106. base = 16
  107. lead, pattern, remnant = re.split(regex, string, maxsplit=1)
  108. parsed_range = parse_numeric_range(pattern, base)
  109. for i in parsed_range:
  110. if re.search(regex, remnant):
  111. for string in expand_ipaddress_pattern(remnant, family):
  112. yield ''.join([lead, format(i, 'x' if family == 6 else 'd'), string])
  113. else:
  114. yield ''.join([lead, format(i, 'x' if family == 6 else 'd'), remnant])
  115. def get_field_value(form, field_name):
  116. """
  117. Return the current bound or initial value associated with a form field, prior to calling
  118. clean() for the form.
  119. """
  120. field = form.fields[field_name]
  121. if form.is_bound and (data := form.data.get(field_name)):
  122. if hasattr(field, 'valid_value') and field.valid_value(data):
  123. return data
  124. return form.get_initial_for_field(field, field_name)
  125. def get_selected_values(form, field_name):
  126. """
  127. Return the list of selected human-friendly values for a form field
  128. """
  129. if not hasattr(form, 'cleaned_data'):
  130. form.is_valid()
  131. filter_data = form.cleaned_data.get(field_name)
  132. field = form.fields[field_name]
  133. # Non-selection field
  134. if not hasattr(field, 'choices'):
  135. return [str(filter_data)]
  136. # Model choice field
  137. if type(field.choices) is forms.models.ModelChoiceIterator:
  138. # If this is a single-choice field, wrap its value in a list
  139. if not hasattr(filter_data, '__iter__'):
  140. values = [filter_data]
  141. else:
  142. values = filter_data
  143. else:
  144. # Static selection field
  145. choices = unpack_grouped_choices(field.choices)
  146. if type(filter_data) not in (list, tuple):
  147. filter_data = [filter_data] # Ensure filter data is iterable
  148. values = [
  149. label for value, label in choices if str(value) in filter_data or None in filter_data
  150. ]
  151. # If the field has a `null_option` attribute set and it is selected,
  152. # add it to the field's grouped choices.
  153. if getattr(field, 'null_option', None) and None in filter_data:
  154. values.remove(None)
  155. values.insert(0, field.null_option)
  156. return values
  157. def add_blank_choice(choices):
  158. """
  159. Add a blank choice to the beginning of a choices list.
  160. """
  161. return ((None, '---------'),) + tuple(choices)
  162. def form_from_model(model, fields):
  163. """
  164. Return a Form class with the specified fields derived from a model. This is useful when we need a form to be used
  165. for creating objects, but want to avoid the model's validation (e.g. for bulk create/edit functions). All fields
  166. are marked as not required.
  167. """
  168. form_fields = fields_for_model(model, fields=fields)
  169. for field in form_fields.values():
  170. field.required = False
  171. return type('FormFromModel', (forms.Form,), form_fields)
  172. def restrict_form_fields(form, user, action='view'):
  173. """
  174. Restrict all form fields which reference a RestrictedQuerySet. This ensures that users see only permitted objects
  175. as available choices.
  176. """
  177. for field in form.fields.values():
  178. if hasattr(field, 'queryset') and issubclass(field.queryset.__class__, RestrictedQuerySet):
  179. field.queryset = field.queryset.restrict(user, action)
  180. def parse_csv(reader):
  181. """
  182. Parse a csv_reader object into a headers dictionary and a list of records dictionaries. Raise an error
  183. if the records are formatted incorrectly. Return headers and records as a tuple.
  184. """
  185. records = []
  186. headers = {}
  187. # Consume the first line of CSV data as column headers. Create a dictionary mapping each header to an optional
  188. # "to" field specifying how the related object is being referenced. For example, importing a Device might use a
  189. # `site.slug` header, to indicate the related site is being referenced by its slug.
  190. for header in next(reader):
  191. header = header.strip()
  192. if '.' in header:
  193. field, to_field = header.split('.', 1)
  194. if field in headers:
  195. raise forms.ValidationError(_('Duplicate or conflicting column header for "{field}"').format(
  196. field=field
  197. ))
  198. headers[field] = to_field
  199. else:
  200. if header in headers:
  201. raise forms.ValidationError(_('Duplicate or conflicting column header for "{header}"').format(
  202. header=header
  203. ))
  204. headers[header] = None
  205. # Parse CSV rows into a list of dictionaries mapped from the column headers.
  206. for i, row in enumerate(reader, start=1):
  207. if len(row) != len(headers):
  208. raise forms.ValidationError(
  209. _("Row {row}: Expected {count_expected} columns but found {count_found}").format(
  210. row=i, count_expected=len(headers), count_found=len(row)
  211. )
  212. )
  213. row = [col.strip() for col in row]
  214. record = dict(zip(headers.keys(), row))
  215. records.append(record)
  216. return headers, records
  217. def validate_csv(headers, fields, required_fields):
  218. """
  219. Validate that parsed csv data conforms to the object's available fields. Raise validation errors
  220. if parsed csv data contains invalid headers or does not contain required headers.
  221. """
  222. # Validate provided column headers
  223. is_update = False
  224. for field, to_field in headers.items():
  225. if field == "id":
  226. is_update = True
  227. continue
  228. if field not in fields:
  229. raise forms.ValidationError(_('Unexpected column header "{field}" found.').format(field=field))
  230. if to_field and not hasattr(fields[field], 'to_field_name'):
  231. raise forms.ValidationError(_('Column "{field}" is not a related object; cannot use dots').format(
  232. field=field
  233. ))
  234. if to_field and not hasattr(fields[field].queryset.model, to_field):
  235. raise forms.ValidationError(_('Invalid related object attribute for column "{field}": {to_field}').format(
  236. field=field, to_field=to_field
  237. ))
  238. # Validate required fields (if not an update)
  239. if not is_update:
  240. for f in required_fields:
  241. if f not in headers:
  242. raise forms.ValidationError(_('Required column header "{header}" not found.').format(header=f))