base.py 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173
  1. import json
  2. from django.contrib.auth import get_user_model
  3. from django.contrib.contenttypes.models import ContentType
  4. from django.contrib.postgres.fields import ArrayField
  5. from django.core.exceptions import FieldDoesNotExist
  6. from django.db.models import ManyToManyField, JSONField
  7. from django.forms.models import model_to_dict
  8. from django.test import Client, TestCase as _TestCase
  9. from netaddr import IPNetwork
  10. from taggit.managers import TaggableManager
  11. from users.models import ObjectPermission
  12. from utilities.permissions import resolve_permission_ct
  13. from utilities.utils import content_type_identifier
  14. from .utils import extract_form_failures
  15. __all__ = (
  16. 'ModelTestCase',
  17. 'TestCase',
  18. )
  19. class TestCase(_TestCase):
  20. user_permissions = ()
  21. def setUp(self):
  22. # Create the test user and assign permissions
  23. self.user = get_user_model().objects.create_user(username='testuser')
  24. self.add_permissions(*self.user_permissions)
  25. # Initialize the test client
  26. self.client = Client()
  27. self.client.force_login(self.user)
  28. #
  29. # Permissions management
  30. #
  31. def add_permissions(self, *names):
  32. """
  33. Assign a set of permissions to the test user. Accepts permission names in the form <app>.<action>_<model>.
  34. """
  35. for name in names:
  36. ct, action = resolve_permission_ct(name)
  37. obj_perm = ObjectPermission(name=name, actions=[action])
  38. obj_perm.save()
  39. obj_perm.users.add(self.user)
  40. obj_perm.object_types.add(ct)
  41. #
  42. # Custom assertions
  43. #
  44. def assertHttpStatus(self, response, expected_status):
  45. """
  46. TestCase method. Provide more detail in the event of an unexpected HTTP response.
  47. """
  48. err_message = None
  49. # Construct an error message only if we know the test is going to fail
  50. if response.status_code != expected_status:
  51. if hasattr(response, 'data'):
  52. # REST API response; pass the response data through directly
  53. err = response.data
  54. else:
  55. # Attempt to extract form validation errors from the response HTML
  56. form_errors = extract_form_failures(response.content)
  57. err = form_errors or response.content or 'No data'
  58. err_message = f"Expected HTTP status {expected_status}; received {response.status_code}: {err}"
  59. self.assertEqual(response.status_code, expected_status, err_message)
  60. class ModelTestCase(TestCase):
  61. """
  62. Parent class for TestCases which deal with models.
  63. """
  64. model = None
  65. def _get_queryset(self):
  66. """
  67. Return a base queryset suitable for use in test methods.
  68. """
  69. return self.model.objects.all()
  70. def prepare_instance(self, instance):
  71. """
  72. Test cases can override this method to perform any necessary manipulation of an instance prior to its evaluation
  73. against test data. For example, it can be used to decrypt a Secret's plaintext attribute.
  74. """
  75. return instance
  76. def model_to_dict(self, instance, fields, api=False):
  77. """
  78. Return a dictionary representation of an instance.
  79. """
  80. # Prepare the instance and call Django's model_to_dict() to extract all fields
  81. model_dict = model_to_dict(self.prepare_instance(instance), fields=fields)
  82. # Map any additional (non-field) instance attributes that were specified
  83. for attr in fields:
  84. if hasattr(instance, attr) and attr not in model_dict:
  85. model_dict[attr] = getattr(instance, attr)
  86. for key, value in list(model_dict.items()):
  87. try:
  88. field = instance._meta.get_field(key)
  89. except FieldDoesNotExist:
  90. # Attribute is not a model field
  91. continue
  92. # Handle ManyToManyFields
  93. if value and type(field) in (ManyToManyField, TaggableManager):
  94. if field.related_model is ContentType and api:
  95. model_dict[key] = sorted([content_type_identifier(ct) for ct in value])
  96. else:
  97. model_dict[key] = sorted([obj.pk for obj in value])
  98. elif api:
  99. # Replace ContentType numeric IDs with <app_label>.<model>
  100. if type(getattr(instance, key)) is ContentType:
  101. ct = ContentType.objects.get(pk=value)
  102. model_dict[key] = content_type_identifier(ct)
  103. # Convert IPNetwork instances to strings
  104. elif type(value) is IPNetwork:
  105. model_dict[key] = str(value)
  106. else:
  107. field = instance._meta.get_field(key)
  108. # Convert ArrayFields to CSV strings
  109. if type(field) is ArrayField:
  110. if type(field.base_field) is ArrayField:
  111. # Handle nested arrays (e.g. choice sets)
  112. model_dict[key] = '\n'.join([f'{k},{v}' for k, v in value])
  113. else:
  114. model_dict[key] = ','.join([str(v) for v in value])
  115. # JSON
  116. if type(field) is JSONField and value is not None:
  117. model_dict[key] = json.dumps(value)
  118. return model_dict
  119. #
  120. # Custom assertions
  121. #
  122. def assertInstanceEqual(self, instance, data, exclude=None, api=False):
  123. """
  124. Compare a model instance to a dictionary, checking that its attribute values match those specified
  125. in the dictionary.
  126. :param instance: Python object instance
  127. :param data: Dictionary of test data used to define the instance
  128. :param exclude: List of fields to exclude from comparison (e.g. passwords, which get hashed)
  129. :param api: Set to True is the data is a JSON representation of the instance
  130. """
  131. if exclude is None:
  132. exclude = []
  133. fields = [k for k in data.keys() if k not in exclude]
  134. model_dict = self.model_to_dict(instance, fields=fields, api=api)
  135. # Omit any dictionary keys which are not instance attributes or have been excluded
  136. relevant_data = {
  137. k: v for k, v in data.items() if hasattr(instance, k) and k not in exclude
  138. }
  139. self.assertDictEqual(model_dict, relevant_data)