utils.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576
  1. import datetime
  2. import decimal
  3. import json
  4. import re
  5. from decimal import Decimal
  6. from itertools import count, groupby
  7. from urllib.parse import urlencode
  8. import nh3
  9. from django.contrib.contenttypes.models import ContentType
  10. from django.core import serializers
  11. from django.db.models import Count, ManyToOneRel, OuterRef, Subquery
  12. from django.db.models.functions import Coalesce
  13. from django.http import QueryDict
  14. from django.utils import timezone
  15. from django.utils.datastructures import MultiValueDict
  16. from django.utils.html import escape
  17. from django.utils.timezone import localtime
  18. from django.utils.translation import gettext as _
  19. from jinja2.sandbox import SandboxedEnvironment
  20. from mptt.models import MPTTModel
  21. from dcim.choices import CableLengthUnitChoices, WeightUnitChoices
  22. from extras.utils import is_taggable
  23. from netbox.config import get_config
  24. from netbox.plugins import PluginConfig
  25. from utilities.constants import HTTP_REQUEST_META_SAFE_COPY
  26. from .constants import HTML_ALLOWED_ATTRIBUTES, HTML_ALLOWED_TAGS
  27. def title(value):
  28. """
  29. Improved implementation of str.title(); retains all existing uppercase letters.
  30. """
  31. return ' '.join([w[0].upper() + w[1:] for w in str(value).split()])
  32. def get_viewname(model, action=None, rest_api=False):
  33. """
  34. Return the view name for the given model and action, if valid.
  35. :param model: The model or instance to which the view applies
  36. :param action: A string indicating the desired action (if any); e.g. "add" or "list"
  37. :param rest_api: A boolean indicating whether this is a REST API view
  38. """
  39. is_plugin = isinstance(model._meta.app_config, PluginConfig)
  40. app_label = model._meta.app_label
  41. model_name = model._meta.model_name
  42. if rest_api:
  43. viewname = f'{app_label}-api:{model_name}'
  44. if is_plugin:
  45. viewname = f'plugins-api:{viewname}'
  46. if action:
  47. viewname = f'{viewname}-{action}'
  48. else:
  49. viewname = f'{app_label}:{model_name}'
  50. if is_plugin:
  51. viewname = f'plugins:{viewname}'
  52. if action:
  53. viewname = f'{viewname}_{action}'
  54. return viewname
  55. def csv_format(data):
  56. """
  57. Encapsulate any data which contains a comma within double quotes.
  58. """
  59. csv = []
  60. for value in data:
  61. # Represent None or False with empty string
  62. if value is None or value is False:
  63. csv.append('')
  64. continue
  65. # Convert dates to ISO format
  66. if isinstance(value, (datetime.date, datetime.datetime)):
  67. value = value.isoformat()
  68. # Force conversion to string first so we can check for any commas
  69. if not isinstance(value, str):
  70. value = '{}'.format(value)
  71. # Double-quote the value if it contains a comma or line break
  72. if ',' in value or '\n' in value:
  73. value = value.replace('"', '""') # Escape double-quotes
  74. csv.append('"{}"'.format(value))
  75. else:
  76. csv.append('{}'.format(value))
  77. return ','.join(csv)
  78. def foreground_color(bg_color, dark='000000', light='ffffff'):
  79. """
  80. Return the ideal foreground color (dark or light) for a given background color in hexadecimal RGB format.
  81. :param dark: RBG color code for dark text
  82. :param light: RBG color code for light text
  83. """
  84. THRESHOLD = 150
  85. bg_color = bg_color.strip('#')
  86. r, g, b = [int(bg_color[c:c + 2], 16) for c in (0, 2, 4)]
  87. if r * 0.299 + g * 0.587 + b * 0.114 > THRESHOLD:
  88. return dark
  89. else:
  90. return light
  91. def dynamic_import(name):
  92. """
  93. Dynamically import a class from an absolute path string
  94. """
  95. components = name.split('.')
  96. mod = __import__(components[0])
  97. for comp in components[1:]:
  98. mod = getattr(mod, comp)
  99. return mod
  100. def count_related(model, field):
  101. """
  102. Return a Subquery suitable for annotating a child object count.
  103. """
  104. subquery = Subquery(
  105. model.objects.filter(
  106. **{field: OuterRef('pk')}
  107. ).order_by().values(
  108. field
  109. ).annotate(
  110. c=Count('*')
  111. ).values('c')
  112. )
  113. return Coalesce(subquery, 0)
  114. def serialize_object(obj, resolve_tags=True, extra=None, exclude=None):
  115. """
  116. Return a generic JSON representation of an object using Django's built-in serializer. (This is used for things like
  117. change logging, not the REST API.) Optionally include a dictionary to supplement the object data. A list of keys
  118. can be provided to exclude them from the returned dictionary. Private fields (prefaced with an underscore) are
  119. implicitly excluded.
  120. Args:
  121. obj: The object to serialize
  122. resolve_tags: If true, any assigned tags will be represented by their names
  123. extra: Any additional data to include in the serialized output. Keys provided in this mapping will
  124. override object attributes.
  125. exclude: An iterable of attributes to exclude from the serialized output
  126. """
  127. json_str = serializers.serialize('json', [obj])
  128. data = json.loads(json_str)[0]['fields']
  129. exclude = exclude or []
  130. # Exclude any MPTTModel fields
  131. if issubclass(obj.__class__, MPTTModel):
  132. for field in ['level', 'lft', 'rght', 'tree_id']:
  133. data.pop(field)
  134. # Include custom_field_data as "custom_fields"
  135. if hasattr(obj, 'custom_field_data'):
  136. data['custom_fields'] = data.pop('custom_field_data')
  137. # Resolve any assigned tags to their names. Check for tags cached on the instance;
  138. # fall back to using the manager.
  139. if resolve_tags and is_taggable(obj):
  140. tags = getattr(obj, '_tags', None) or obj.tags.all()
  141. data['tags'] = sorted([tag.name for tag in tags])
  142. # Skip excluded and private (prefixes with an underscore) attributes
  143. for key in list(data.keys()):
  144. if key in exclude or (isinstance(key, str) and key.startswith('_')):
  145. data.pop(key)
  146. # Append any extra data
  147. if extra is not None:
  148. data.update(extra)
  149. return data
  150. def deserialize_object(model, fields, pk=None):
  151. """
  152. Instantiate an object from the given model and field data. Functions as
  153. the complement to serialize_object().
  154. """
  155. content_type = ContentType.objects.get_for_model(model)
  156. if 'custom_fields' in fields:
  157. fields['custom_field_data'] = fields.pop('custom_fields')
  158. data = {
  159. 'model': '.'.join(content_type.natural_key()),
  160. 'pk': pk,
  161. 'fields': fields,
  162. }
  163. instance = list(serializers.deserialize('python', [data]))[0]
  164. return instance
  165. def dict_to_filter_params(d, prefix=''):
  166. """
  167. Translate a dictionary of attributes to a nested set of parameters suitable for QuerySet filtering. For example:
  168. {
  169. "name": "Foo",
  170. "rack": {
  171. "facility_id": "R101"
  172. }
  173. }
  174. Becomes:
  175. {
  176. "name": "Foo",
  177. "rack__facility_id": "R101"
  178. }
  179. And can be employed as filter parameters:
  180. Device.objects.filter(**dict_to_filter(attrs_dict))
  181. """
  182. params = {}
  183. for key, val in d.items():
  184. k = prefix + key
  185. if isinstance(val, dict):
  186. params.update(dict_to_filter_params(val, k + '__'))
  187. else:
  188. params[k] = val
  189. return params
  190. def dict_to_querydict(d, mutable=True):
  191. """
  192. Create a QueryDict instance from a regular Python dictionary.
  193. """
  194. qd = QueryDict(mutable=True)
  195. for k, v in d.items():
  196. item = MultiValueDict({k: v}) if isinstance(v, (list, tuple, set)) else {k: v}
  197. qd.update(item)
  198. if not mutable:
  199. qd._mutable = False
  200. return qd
  201. def normalize_querydict(querydict):
  202. """
  203. Convert a QueryDict to a normal, mutable dictionary, preserving list values. For example,
  204. QueryDict('foo=1&bar=2&bar=3&baz=')
  205. becomes:
  206. {'foo': '1', 'bar': ['2', '3'], 'baz': ''}
  207. This function is necessary because QueryDict does not provide any built-in mechanism which preserves multiple
  208. values.
  209. """
  210. return {
  211. k: v if len(v) > 1 else v[0] for k, v in querydict.lists()
  212. }
  213. def deepmerge(original, new):
  214. """
  215. Deep merge two dictionaries (new into original) and return a new dict
  216. """
  217. merged = dict(original)
  218. for key, val in new.items():
  219. if key in original and isinstance(original[key], dict) and val and isinstance(val, dict):
  220. merged[key] = deepmerge(original[key], val)
  221. else:
  222. merged[key] = val
  223. return merged
  224. def drange(start, end, step=decimal.Decimal(1)):
  225. """
  226. Decimal-compatible implementation of Python's range()
  227. """
  228. start, end, step = decimal.Decimal(start), decimal.Decimal(end), decimal.Decimal(step)
  229. if start < end:
  230. while start < end:
  231. yield start
  232. start += step
  233. else:
  234. while start > end:
  235. yield start
  236. start += step
  237. def to_meters(length, unit):
  238. """
  239. Convert the given length to meters.
  240. """
  241. try:
  242. if length < 0:
  243. raise ValueError(_("Length must be a positive number"))
  244. except TypeError:
  245. raise TypeError(_("Invalid value '{length}' for length (must be a number)").format(length=length))
  246. valid_units = CableLengthUnitChoices.values()
  247. if unit not in valid_units:
  248. raise ValueError(
  249. _("Unknown unit {unit}. Must be one of the following: {valid_units}").format(
  250. unit=unit, valid_units=', '.join(valid_units)
  251. )
  252. )
  253. if unit == CableLengthUnitChoices.UNIT_KILOMETER:
  254. return length * 1000
  255. if unit == CableLengthUnitChoices.UNIT_METER:
  256. return length
  257. if unit == CableLengthUnitChoices.UNIT_CENTIMETER:
  258. return length / 100
  259. if unit == CableLengthUnitChoices.UNIT_MILE:
  260. return length * Decimal(1609.344)
  261. if unit == CableLengthUnitChoices.UNIT_FOOT:
  262. return length * Decimal(0.3048)
  263. if unit == CableLengthUnitChoices.UNIT_INCH:
  264. return length * Decimal(0.0254)
  265. raise ValueError(_("Unknown unit {unit}. Must be 'km', 'm', 'cm', 'mi', 'ft', or 'in'.").format(unit=unit))
  266. def to_grams(weight, unit):
  267. """
  268. Convert the given weight to kilograms.
  269. """
  270. try:
  271. if weight < 0:
  272. raise ValueError(_("Weight must be a positive number"))
  273. except TypeError:
  274. raise TypeError(_("Invalid value '{weight}' for weight (must be a number)").format(weight=weight))
  275. valid_units = WeightUnitChoices.values()
  276. if unit not in valid_units:
  277. raise ValueError(
  278. _("Unknown unit {unit}. Must be one of the following: {valid_units}").format(
  279. unit=unit, valid_units=', '.join(valid_units)
  280. )
  281. )
  282. if unit == WeightUnitChoices.UNIT_KILOGRAM:
  283. return weight * 1000
  284. if unit == WeightUnitChoices.UNIT_GRAM:
  285. return weight
  286. if unit == WeightUnitChoices.UNIT_POUND:
  287. return weight * Decimal(453.592)
  288. if unit == WeightUnitChoices.UNIT_OUNCE:
  289. return weight * Decimal(28.3495)
  290. raise ValueError(_("Unknown unit {unit}. Must be 'kg', 'g', 'lb', 'oz'.").format(unit=unit))
  291. def render_jinja2(template_code, context):
  292. """
  293. Render a Jinja2 template with the provided context. Return the rendered content.
  294. """
  295. environment = SandboxedEnvironment()
  296. environment.filters.update(get_config().JINJA2_FILTERS)
  297. return environment.from_string(source=template_code).render(**context)
  298. def prepare_cloned_fields(instance):
  299. """
  300. Generate a QueryDict comprising attributes from an object's clone() method.
  301. """
  302. # Generate the clone attributes from the instance
  303. if not hasattr(instance, 'clone'):
  304. return QueryDict(mutable=True)
  305. attrs = instance.clone()
  306. # Prepare querydict parameters
  307. params = []
  308. for key, value in attrs.items():
  309. if type(value) in (list, tuple):
  310. params.extend([(key, v) for v in value])
  311. elif value not in (False, None):
  312. params.append((key, value))
  313. else:
  314. params.append((key, ''))
  315. # Return a QueryDict with the parameters
  316. return QueryDict(urlencode(params), mutable=True)
  317. def shallow_compare_dict(source_dict, destination_dict, exclude=tuple()):
  318. """
  319. Return a new dictionary of the different keys. The values of `destination_dict` are returned. Only the equality of
  320. the first layer of keys/values is checked. `exclude` is a list or tuple of keys to be ignored.
  321. """
  322. difference = {}
  323. for key, value in destination_dict.items():
  324. if key in exclude:
  325. continue
  326. if source_dict.get(key) != value:
  327. difference[key] = value
  328. return difference
  329. def flatten_dict(d, prefix='', separator='.'):
  330. """
  331. Flatten netsted dictionaries into a single level by joining key names with a separator.
  332. :param d: The dictionary to be flattened
  333. :param prefix: Initial prefix (if any)
  334. :param separator: The character to use when concatenating key names
  335. """
  336. ret = {}
  337. for k, v in d.items():
  338. key = separator.join([prefix, k]) if prefix else k
  339. if type(v) is dict:
  340. ret.update(flatten_dict(v, prefix=key, separator=separator))
  341. else:
  342. ret[key] = v
  343. return ret
  344. def array_to_ranges(array):
  345. """
  346. Convert an arbitrary array of integers to a list of consecutive values. Nonconsecutive values are returned as
  347. single-item tuples. For example:
  348. [0, 1, 2, 10, 14, 15, 16] => [(0, 2), (10,), (14, 16)]"
  349. """
  350. group = (
  351. list(x) for _, x in groupby(sorted(array), lambda x, c=count(): next(c) - x)
  352. )
  353. return [
  354. (g[0], g[-1])[:len(g)] for g in group
  355. ]
  356. def array_to_string(array):
  357. """
  358. Generate an efficient, human-friendly string from a set of integers. Intended for use with ArrayField.
  359. For example:
  360. [0, 1, 2, 10, 14, 15, 16] => "0-2, 10, 14-16"
  361. """
  362. ret = []
  363. ranges = array_to_ranges(array)
  364. for value in ranges:
  365. if len(value) == 1:
  366. ret.append(str(value[0]))
  367. else:
  368. ret.append(f'{value[0]}-{value[1]}')
  369. return ', '.join(ret)
  370. def content_type_name(ct, include_app=True):
  371. """
  372. Return a human-friendly ContentType name (e.g. "DCIM > Site").
  373. """
  374. try:
  375. meta = ct.model_class()._meta
  376. app_label = title(meta.app_config.verbose_name)
  377. model_name = title(meta.verbose_name)
  378. if include_app:
  379. return f'{app_label} > {model_name}'
  380. return model_name
  381. except AttributeError:
  382. # Model no longer exists
  383. return f'{ct.app_label} > {ct.model}'
  384. def content_type_identifier(ct):
  385. """
  386. Return a "raw" ContentType identifier string suitable for bulk import/export (e.g. "dcim.site").
  387. """
  388. return f'{ct.app_label}.{ct.model}'
  389. #
  390. # Fake request object
  391. #
  392. class NetBoxFakeRequest:
  393. """
  394. A fake request object which is explicitly defined at the module level so it is able to be pickled. It simply
  395. takes what is passed to it as kwargs on init and sets them as instance variables.
  396. """
  397. def __init__(self, _dict):
  398. self.__dict__ = _dict
  399. def copy_safe_request(request):
  400. """
  401. Copy selected attributes from a request object into a new fake request object. This is needed in places where
  402. thread safe pickling of the useful request data is needed.
  403. """
  404. meta = {
  405. k: request.META[k]
  406. for k in HTTP_REQUEST_META_SAFE_COPY
  407. if k in request.META and isinstance(request.META[k], str)
  408. }
  409. return NetBoxFakeRequest({
  410. 'META': meta,
  411. 'COOKIES': request.COOKIES,
  412. 'POST': request.POST,
  413. 'GET': request.GET,
  414. 'FILES': request.FILES,
  415. 'user': request.user,
  416. 'path': request.path,
  417. 'id': getattr(request, 'id', None), # UUID assigned by middleware
  418. })
  419. def clean_html(html, schemes):
  420. """
  421. Sanitizes HTML based on a whitelist of allowed tags and attributes.
  422. Also takes a list of allowed URI schemes.
  423. """
  424. return nh3.clean(
  425. html,
  426. tags=HTML_ALLOWED_TAGS,
  427. attributes=HTML_ALLOWED_ATTRIBUTES,
  428. url_schemes=set(schemes)
  429. )
  430. def highlight_string(value, highlight, trim_pre=None, trim_post=None, trim_placeholder='...'):
  431. """
  432. Highlight a string within a string and optionally trim the pre/post portions of the original string.
  433. Args:
  434. value: The body of text being searched against
  435. highlight: The string of compiled regex pattern to highlight in `value`
  436. trim_pre: Maximum length of pre-highlight text to include
  437. trim_post: Maximum length of post-highlight text to include
  438. trim_placeholder: String value to swap in for trimmed pre/post text
  439. """
  440. # Split value on highlight string
  441. try:
  442. if type(highlight) is re.Pattern:
  443. pre, match, post = highlight.split(value, maxsplit=1)
  444. else:
  445. highlight = re.escape(highlight)
  446. pre, match, post = re.split(fr'({highlight})', value, maxsplit=1, flags=re.IGNORECASE)
  447. except ValueError as e:
  448. # Match not found
  449. return escape(value)
  450. # Trim pre/post sections to length
  451. if trim_pre and len(pre) > trim_pre:
  452. pre = trim_placeholder + pre[-trim_pre:]
  453. if trim_post and len(post) > trim_post:
  454. post = post[:trim_post] + trim_placeholder
  455. return f'{escape(pre)}<mark>{escape(match)}</mark>{escape(post)}'
  456. def local_now():
  457. """
  458. Return the current date & time in the system timezone.
  459. """
  460. return localtime(timezone.now())
  461. def get_related_models(model, ordered=True):
  462. """
  463. Return a list of all models which have a ForeignKey to the given model and the name of the field. For example,
  464. `get_related_models(Tenant)` will return all models which have a ForeignKey relationship to Tenant.
  465. """
  466. related_models = [
  467. (field.related_model, field.remote_field.name)
  468. for field in model._meta.related_objects
  469. if type(field) is ManyToOneRel
  470. ]
  471. if ordered:
  472. return sorted(related_models, key=lambda x: x[0]._meta.verbose_name.lower())
  473. return related_models