utils.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385
  1. import datetime
  2. import json
  3. from collections import OrderedDict
  4. from decimal import Decimal
  5. from itertools import count, groupby
  6. from django.core.serializers import serialize
  7. from django.db.models import Count, OuterRef, Subquery
  8. from django.db.models.functions import Coalesce
  9. from django.http import QueryDict
  10. from jinja2.sandbox import SandboxedEnvironment
  11. from mptt.models import MPTTModel
  12. from dcim.choices import CableLengthUnitChoices
  13. from extras.plugins import PluginConfig
  14. from extras.utils import is_taggable
  15. from utilities.constants import HTTP_REQUEST_META_SAFE_COPY
  def get_viewname(model, action=None, rest_api=False):
      """
      Build the URL view name for a model, optionally qualified by an action.

      Web UI names take the form "<app_label>:<model_name>[_<action>]" and REST API names the form
      "<app_label>-api:<model_name>[-<action>]". Models whose app config is a PluginConfig are additionally
      namespaced under "plugins:" or "plugins-api:" respectively.

      :param model: The model or instance to which the view applies
      :param action: A string indicating the desired action (if any); e.g. "add" or "list"
      :param rest_api: A boolean indicating whether this is a REST API view
      :return: The view name as a string
      """
      meta = model._meta
      is_plugin = isinstance(meta.app_config, PluginConfig)
      app_label, model_name = meta.app_label, meta.model_name

      if rest_api:
          viewname = (
              f'plugins-api:{app_label}-api:{model_name}'
              if is_plugin
              else f'{app_label}-api:{model_name}'
          )
          # REST API view names join the action with a hyphen
          if not action:
              return viewname
          return f'{viewname}-{action}'

      viewname = f'{app_label}:{model_name}'
      # Plugin models live under the "plugins" namespace
      if is_plugin:
          viewname = f'plugins:{viewname}'
      # UI view names join the action with an underscore
      if not action:
          return viewname
      return f'{viewname}_{action}'
  def csv_format(data):
      """
      Render an iterable of values as a single line of comma-separated text.

      None and False become empty fields. Dates and datetimes are written in ISO format. Any other non-string value
      is converted to its string form. A field is wrapped in double quotes (with embedded double quotes doubled) when
      it contains a comma, a double quote, or a line break, so that the row can be parsed back unambiguously.

      :param data: An iterable of values making up one row
      :return: The row as a string, without a trailing line terminator
      """
      # Characters which force a field to be quoted
      special_characters = (',', '"', '\n', '\r')

      csv = []
      for value in data:

          # Represent None or False with empty string
          if value is None or value is False:
              csv.append('')
              continue

          # Convert dates to ISO format
          if isinstance(value, (datetime.date, datetime.datetime)):
              value = value.isoformat()

          # Force conversion to string first so we can inspect its characters
          if not isinstance(value, str):
              value = '{}'.format(value)

          # Quote the value if it contains a delimiter, a quote, or a line break
          if any(character in value for character in special_characters):
              value = value.replace('"', '""')  # Escape double-quotes
              csv.append('"{}"'.format(value))
          else:
              csv.append('{}'.format(value))

      return ','.join(csv)
  def foreground_color(bg_color, dark='000000', light='ffffff'):
      """
      Return the ideal foreground color (dark or light) for a given background color in hexadecimal RGB format.

      The background's brightness is estimated as a weighted sum of its red, green and blue channels; backgrounds
      brighter than the threshold get the dark foreground, all others the light one.

      :param bg_color: Background color as a hex string, with or without "#". Both six-digit ("ff8800") and
          three-digit shorthand ("f80") forms are accepted.
      :param dark: RGB color code for dark text
      :param light: RGB color code for light text
      :return: Either `dark` or `light`
      :raises ValueError: If the color cannot be parsed as hexadecimal
      """
      THRESHOLD = 150
      bg_color = bg_color.strip('#')

      # Expand shorthand notation, e.g. "f80" -> "ff8800"
      if len(bg_color) == 3:
          bg_color = ''.join(digit * 2 for digit in bg_color)

      r, g, b = [int(bg_color[c:c + 2], 16) for c in (0, 2, 4)]
      if r * 0.299 + g * 0.587 + b * 0.114 > THRESHOLD:
          return dark
      return light
  def dynamic_import(name):
      """
      Dynamically import an object (typically a class) from an absolute dotted path string.

      The top-level package is imported first, then each remaining path component is resolved as an attribute of the
      previous one. If a component is not available as an attribute, it is imported as a submodule instead. This
      covers submodules which have not been loaded yet.

      :param name: Absolute dotted path, e.g. "package.module.ClassName"
      :return: The object found at the end of the path
      :raises AttributeError: If a component is neither an attribute nor an importable submodule
      :raises ImportError: If the top-level package cannot be imported
      """
      import importlib

      components = name.split('.')
      target = importlib.import_module(components[0])

      # `depth` is the number of leading components which make up the path to `component`
      for depth, component in enumerate(components[1:], start=2):
          try:
              target = getattr(target, component)
          except AttributeError:
              dotted_path = '.'.join(components[:depth])
              try:
                  target = importlib.import_module(dotted_path)
              except ImportError:
                  parent_path = '.'.join(components[:depth - 1])
                  raise AttributeError(
                      f"'{parent_path}' has no attribute or submodule '{component}'"
                  ) from None

      return target
  def count_related(model, field):
      """
      Build an expression which annotates a parent queryset with its number of related `model` rows.

      The subquery selects the rows of `model` whose `field` references the outer row's primary key, groups them by
      that field, and yields the row count. The result is coalesced to 0 for parents with no matching rows.

      :param model: The related (child) model to count
      :param field: Name of the field on `model` which points at the parent object
      :return: A Coalesce expression wrapping the counting Subquery
      """
      matching_rows = model.objects.filter(**{field: OuterRef('pk')})
      # Clear any default ordering before grouping on the relation field
      grouped = matching_rows.order_by().values(field)
      counts = grouped.annotate(c=Count('*')).values('c')

      return Coalesce(Subquery(counts), 0)
  def serialize_object(obj, extra=None):
      """
      Return a generic JSON representation of an object using Django's built-in serializer. (This is used for things like
      change logging, not the REST API.) Optionally include a dictionary to supplement the object data. Private fields
      (prefaced with an underscore) are implicitly excluded.

      :param obj: The model instance to serialize
      :param extra: Optional dictionary of additional data merged over the serialized fields
      :return: A dictionary of the object's serialized fields
      """
      json_str = serialize('json', [obj])
      data = json.loads(json_str)[0]['fields']

      # Exclude any MPTTModel fields
      if issubclass(obj.__class__, MPTTModel):
          for field in ['level', 'lft', 'rght', 'tree_id']:
              data.pop(field)

      # Include custom_field_data as "custom_fields"
      if hasattr(obj, 'custom_field_data'):
          data['custom_fields'] = data.pop('custom_field_data')

      # Include any tags. Check for tags cached on the instance; fall back to using the manager. (An empty cached
      # value is falsy, so it also falls back to the manager.)
      if is_taggable(obj):
          tags = getattr(obj, '_tags', None) or obj.tags.all()
          data['tags'] = [tag.name for tag in tags]

      # Append any extra data
      if extra is not None:
          data.update(extra)

      # Copy keys to list to avoid 'dictionary changed size during iteration' exception
      for key in list(data):
          # Private fields shouldn't be logged in the object change
          if isinstance(key, str) and key.startswith('_'):
              data.pop(key)

      return data
  def dict_to_filter_params(d, prefix=''):
      """
      Collapse a nested dictionary of attributes into flat, double-underscore-joined lookup parameters suitable for
      QuerySet filtering. For example:

          {
              "name": "Foo",
              "rack": {
                  "facility_id": "R101"
              }
          }

      Becomes:

          {
              "name": "Foo",
              "rack__facility_id": "R101"
          }

      And can be employed as filter parameters:

          Device.objects.filter(**dict_to_filter_params(attrs_dict))

      :param d: The (possibly nested) dictionary of attributes
      :param prefix: Lookup path prepended to every key; used when recursing into nested dictionaries
      """
      flattened = {}
      for name, value in d.items():
          lookup = prefix + name
          if not isinstance(value, dict):
              flattened[lookup] = value
              continue
          # Descend into the nested dictionary, extending the lookup path
          flattened.update(dict_to_filter_params(value, lookup + '__'))
      return flattened
  def normalize_querydict(querydict):
      """
      Turn a QueryDict into a plain, mutable dictionary while keeping multi-valued keys as lists. For example,

          QueryDict('foo=1&bar=2&bar=3&baz=')

      becomes:

          {'foo': '1', 'bar': ['2', '3'], 'baz': ''}

      This function is necessary because QueryDict does not provide any built-in mechanism which preserves multiple
      values.
      """
      normalized = {}
      for name, values in querydict.lists():
          # Keys with a single value are unwrapped; keys with several values keep their list
          normalized[name] = values if len(values) > 1 else values[0]
      return normalized
  def deepmerge(original, new):
      """
      Recursively merge `new` into `original` and return the result as a new OrderedDict.

      Where both sides hold a dictionary under the same key (and the new one is non-empty), the two are merged
      recursively. In every other case the value from `new` replaces the original value.
      """
      result = OrderedDict(original)
      for name, incoming in new.items():
          mergeable = (
              name in original
              and isinstance(original[name], dict)
              and incoming
              and isinstance(incoming, dict)
          )
          result[name] = deepmerge(original[name], incoming) if mergeable else incoming
      return result
  def to_meters(length, unit):
      """
      Convert the given length to meters.

      :param length: The non-negative length to convert. The mile, foot and inch conversions multiply by a Decimal,
          so for those units the value must be an int or Decimal (Python does not multiply float by Decimal).
      :param unit: One of the values defined by CableLengthUnitChoices
      :return: The length expressed in meters
      :raises TypeError: If `length` cannot be compared with a number
      :raises ValueError: If `length` is negative or `unit` is not a recognized unit
      """
      try:
          if length < 0:
              raise ValueError("Length must be a positive number")
      except TypeError:
          raise TypeError(f"Invalid value '{length}' for length (must be a number)")

      valid_units = CableLengthUnitChoices.values()
      if unit not in valid_units:
          raise ValueError(f"Unknown unit {unit}. Must be one of the following: {', '.join(valid_units)}")

      if unit == CableLengthUnitChoices.UNIT_KILOMETER:
          return length * 1000
      if unit == CableLengthUnitChoices.UNIT_METER:
          return length
      if unit == CableLengthUnitChoices.UNIT_CENTIMETER:
          return length / 100
      # Decimal factors are built from strings so they are exact; building them from floats would carry the binary
      # floating-point approximation into the result.
      if unit == CableLengthUnitChoices.UNIT_MILE:
          return length * Decimal('1609.344')
      if unit == CableLengthUnitChoices.UNIT_FOOT:
          return length * Decimal('0.3048')
      if unit == CableLengthUnitChoices.UNIT_INCH:
          # One inch is 1/12 of a foot: 0.3048 m / 12 = 0.0254 m
          return length * Decimal('0.0254')
      raise ValueError(f"Unknown unit {unit}. Must be 'km', 'm', 'cm', 'mi', 'ft', or 'in'.")
  def render_jinja2(template_code, context):
      """
      Compile `template_code` in a fresh Jinja2 SandboxedEnvironment and render it.

      :param template_code: The template source as a string
      :param context: A mapping whose items are passed to the template as keyword arguments
      :return: The rendered content
      """
      environment = SandboxedEnvironment()
      template = environment.from_string(source=template_code)
      return template.render(**context)
  def prepare_cloned_fields(instance):
      """
      Compile an object's `clone_fields` list into a QueryDict of URL query parameters. Tags are automatically cloned
      where applicable.

      :param instance: The model instance being cloned. Its optional `clone_fields` attribute lists the names of the
          fields to copy.
      :return: A mutable QueryDict holding one entry per cloned value
      """
      from urllib.parse import urlencode

      params = []
      for field_name in getattr(instance, 'clone_fields', []):
          field = instance._meta.get_field(field_name)
          field_value = field.value_from_object(instance)

          # Pass False as null for boolean fields
          if field_value is False:
              params.append((field_name, ''))

          # Omit empty values
          elif field_value not in (None, ''):
              params.append((field_name, field_value))

      # Copy tags
      if is_taggable(instance):
          for tag in instance.tags.all():
              params.append(('tags', tag.pk))

      # URL-encode the pairs so that values containing reserved characters (e.g. "&", "=", "+", "%", "#") survive
      # the round trip through the query string parser intact.
      return QueryDict(urlencode(params), mutable=True)
  def shallow_compare_dict(source_dict, destination_dict, exclude=None):
      """
      Collect the entries of `destination_dict` whose values differ from those in `source_dict`.

      Only the top level of keys/values is compared, and a key missing from `source_dict` is compared as None.
      `exclude` is a list or tuple of keys to be ignored; any other type is disregarded.
      """
      ignored = exclude if isinstance(exclude, (list, tuple)) else ()
      return {
          key: value
          for key, value in destination_dict.items()
          if source_dict.get(key) != value and key not in ignored
      }
  def flatten_dict(d, prefix='', separator='.'):
      """
      Flatten nested dictionaries into a single level by joining key names with a separator.

      Only values which are exactly of type `dict` are descended into; any other value is stored as-is.

      :param d: The dictionary to be flattened
      :param prefix: Initial prefix (if any)
      :param separator: The character to use when concatenating key names
      """
      flat = {}
      for name, value in d.items():
          path = separator.join((prefix, name)) if prefix else name
          if type(value) is not dict:
              flat[path] = value
              continue
          flat.update(flatten_dict(value, prefix=path, separator=separator))
      return flat
  def array_to_string(array):
      """
      Generate an efficient, human-friendly string from a set of integers. Intended for use with ArrayField.

      Values are de-duplicated and sorted, then each run of consecutive integers is collapsed into a "first-last"
      range. For example:

          [0, 1, 2, 10, 14, 15, 16] => "0-2, 10, 14-16"

      :param array: An iterable of integers (duplicates are ignored)
      :return: The comma-separated ranges, or an empty string for empty input
      """
      # Duplicates would split a run in two, so collapse them before grouping
      values = sorted(set(array))

      ranges = []
      # Within a run of consecutive integers, (value - position) stays constant
      for _, members in groupby(enumerate(values), key=lambda pair: pair[1] - pair[0]):
          run = [value for _, value in members]
          ranges.append(str(run[0]) if len(run) == 1 else f'{run[0]}-{run[-1]}')

      return ', '.join(ranges)
  def content_type_name(ct):
      """
      Produce a human-friendly label for a ContentType (e.g. "DCIM > Site").

      The label combines the verbose names of the model's app config and of the model itself. When those cannot be
      resolved (an AttributeError is raised), the ContentType's raw app label and model name are used instead.
      """
      try:
          options = ct.model_class()._meta
          label = f'{options.app_config.verbose_name} > {options.verbose_name}'
      except AttributeError:
          # Model no longer exists
          label = f'{ct.app_label} > {ct.model}'
      return label
  def content_type_identifier(ct):
      """
      Build the "raw" identifier of a ContentType, suitable for bulk import/export (e.g. "dcim.site").

      :param ct: An object exposing `app_label` and `model` attributes
      :return: The string "<app_label>.<model>"
      """
      app_label, model = ct.app_label, ct.model
      return f'{app_label}.{model}'
  281. #
  282. # Fake request object
  283. #
  class NetBoxFakeRequest:
      """
      A stand-in request object, explicitly defined at the module level so it is able to be pickled.

      It is initialized with a single dictionary whose items become the instance's attributes.
      """

      def __init__(self, _dict):
          # The supplied dictionary itself (not a copy) is adopted as the instance's attribute namespace
          self.__dict__ = _dict
  def copy_safe_request(request):
      """
      Copy selected attributes from a request object into a new fake request object. This is needed in places where
      thread safe pickling of the useful request data is needed.

      Only META keys listed in HTTP_REQUEST_META_SAFE_COPY, and whose values are strings, are carried over.

      :param request: The request to copy from
      :return: A NetBoxFakeRequest holding META, POST, GET, FILES, user, path and id
      """
      safe_meta = {}
      for header in HTTP_REQUEST_META_SAFE_COPY:
          if header in request.META and isinstance(request.META[header], str):
              safe_meta[header] = request.META[header]

      attributes = dict(
          META=safe_meta,
          POST=request.POST,
          GET=request.GET,
          FILES=request.FILES,
          user=request.user,
          path=request.path,
          id=getattr(request, 'id', None),  # UUID assigned by middleware
      )
      return NetBoxFakeRequest(attributes)