base.py 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186
  1. import json
  2. from django.contrib.contenttypes.models import ContentType
  3. from django.contrib.postgres.fields import ArrayField, RangeField
  4. from django.core.exceptions import FieldDoesNotExist
  5. from django.db.models import ManyToManyField, ManyToManyRel, JSONField
  6. from django.forms.models import model_to_dict
  7. from django.test import Client, TestCase as _TestCase
  8. from netaddr import IPNetwork
  9. from taggit.managers import TaggableManager
  10. from core.models import ObjectType
  11. from users.models import ObjectPermission, User
  12. from utilities.data import ranges_to_string
  13. from utilities.object_types import object_type_identifier
  14. from utilities.permissions import resolve_permission_type
  15. from .utils import DUMMY_CF_DATA, extract_form_failures
  16. __all__ = (
  17. 'ModelTestCase',
  18. 'TestCase',
  19. )
  20. class TestCase(_TestCase):
  21. user_permissions = ()
  22. def setUp(self):
  23. # Create the test user and assign permissions
  24. self.user = User.objects.create_user(username='testuser')
  25. self.add_permissions(*self.user_permissions)
  26. # Initialize the test client
  27. self.client = Client()
  28. self.client.force_login(self.user)
  29. #
  30. # Permissions management
  31. #
  32. def add_permissions(self, *names):
  33. """
  34. Assign a set of permissions to the test user. Accepts permission names in the form <app>.<action>_<model>.
  35. """
  36. for name in names:
  37. object_type, action = resolve_permission_type(name)
  38. obj_perm = ObjectPermission(name=name, actions=[action])
  39. obj_perm.save()
  40. obj_perm.users.add(self.user)
  41. obj_perm.object_types.add(object_type)
  42. #
  43. # Custom assertions
  44. #
  45. def assertHttpStatus(self, response, expected_status):
  46. """
  47. TestCase method. Provide more detail in the event of an unexpected HTTP response.
  48. """
  49. err_message = None
  50. # Construct an error message only if we know the test is going to fail
  51. if response.status_code != expected_status:
  52. if hasattr(response, 'data'):
  53. # REST API response; pass the response data through directly
  54. err = response.data
  55. else:
  56. # Attempt to extract form validation errors from the response HTML
  57. form_errors = extract_form_failures(response.content)
  58. err = form_errors or response.content or 'No data'
  59. err_message = f"Expected HTTP status {expected_status}; received {response.status_code}: {err}"
  60. self.assertEqual(response.status_code, expected_status, err_message)
  61. class ModelTestCase(TestCase):
  62. """
  63. Parent class for TestCases which deal with models.
  64. """
  65. model = None
  66. def _get_queryset(self):
  67. """
  68. Return a base queryset suitable for use in test methods.
  69. """
  70. return self.model.objects.all()
  71. def prepare_instance(self, instance):
  72. """
  73. Test cases can override this method to perform any necessary manipulation of an instance prior to its evaluation
  74. against test data. For example, it can be used to decrypt a Secret's plaintext attribute.
  75. """
  76. return instance
  77. def model_to_dict(self, instance, fields, api=False):
  78. """
  79. Return a dictionary representation of an instance.
  80. """
  81. # Prepare the instance and call Django's model_to_dict() to extract all fields
  82. model_dict = model_to_dict(self.prepare_instance(instance), fields=fields)
  83. # Map any additional (non-field) instance attributes that were specified
  84. for attr in fields:
  85. if hasattr(instance, attr) and attr not in model_dict:
  86. model_dict[attr] = getattr(instance, attr)
  87. for key, value in list(model_dict.items()):
  88. try:
  89. field = instance._meta.get_field(key)
  90. except FieldDoesNotExist:
  91. # Attribute is not a model field
  92. continue
  93. # Handle ManyToManyFields
  94. if value and type(field) in (ManyToManyField, ManyToManyRel, TaggableManager):
  95. # Resolve reverse M2M relationships
  96. if isinstance(field, ManyToManyRel):
  97. value = getattr(instance, field.related_name).all()
  98. if field.related_model in (ContentType, ObjectType) and api:
  99. model_dict[key] = sorted([object_type_identifier(ot) for ot in value])
  100. else:
  101. model_dict[key] = sorted([obj.pk for obj in value])
  102. elif api:
  103. # Replace ContentType numeric IDs with <app_label>.<model>
  104. if type(getattr(instance, key)) in (ContentType, ObjectType):
  105. object_type = ObjectType.objects.get(pk=value)
  106. model_dict[key] = object_type_identifier(object_type)
  107. # Convert IPNetwork instances to strings
  108. elif type(value) is IPNetwork:
  109. model_dict[key] = str(value)
  110. else:
  111. field = instance._meta.get_field(key)
  112. # Convert ArrayFields to CSV strings
  113. if type(field) is ArrayField:
  114. if getattr(field.base_field, 'choices', None):
  115. # Values for fields with pre-defined choices can be returned as lists
  116. model_dict[key] = value
  117. elif type(field.base_field) is ArrayField:
  118. # Handle nested arrays (e.g. choice sets)
  119. model_dict[key] = '\n'.join([f'{k},{v}' for k, v in value])
  120. elif issubclass(type(field.base_field), RangeField):
  121. # Handle arrays of numeric ranges (e.g. VLANGroup VLAN ID ranges)
  122. model_dict[key] = ranges_to_string(value)
  123. else:
  124. model_dict[key] = ','.join([str(v) for v in value])
  125. # JSON
  126. if type(field) is JSONField and value is not None:
  127. model_dict[key] = json.dumps(value)
  128. return model_dict
  129. #
  130. # Custom assertions
  131. #
  132. def assertInstanceEqual(self, instance, data, exclude=None, api=False):
  133. """
  134. Compare a model instance to a dictionary, checking that its attribute values match those specified
  135. in the dictionary.
  136. :param instance: Python object instance
  137. :param data: Dictionary of test data used to define the instance
  138. :param exclude: List of fields to exclude from comparison (e.g. passwords, which get hashed)
  139. :param api: Set to True is the data is a JSON representation of the instance
  140. """
  141. if exclude is None:
  142. exclude = []
  143. fields = [k for k in data.keys() if k not in exclude]
  144. model_dict = self.model_to_dict(instance, fields=fields, api=api)
  145. # Omit any dictionary keys which are not instance attributes or have been excluded
  146. model_data = {
  147. k: v for k, v in data.items() if hasattr(instance, k) and k not in exclude
  148. }
  149. self.assertDictEqual(model_dict, model_data)
  150. # Validate any custom field data, if present
  151. if getattr(instance, 'custom_field_data', None):
  152. self.assertDictEqual(instance.custom_field_data, DUMMY_CF_DATA)