base.py 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174
  1. import json
  2. from django.contrib.auth import get_user_model
  3. from django.contrib.contenttypes.models import ContentType
  4. from django.contrib.postgres.fields import ArrayField
  5. from django.core.exceptions import FieldDoesNotExist
  6. from django.db.models import ManyToManyField, JSONField
  7. from django.forms.models import model_to_dict
  8. from django.test import Client, TestCase as _TestCase
  9. from netaddr import IPNetwork
  10. from taggit.managers import TaggableManager
  11. from core.models import ObjectType
  12. from users.models import ObjectPermission
  13. from utilities.permissions import resolve_permission_ct
  14. from utilities.utils import content_type_identifier
  15. from .utils import extract_form_failures
  16. __all__ = (
  17. 'ModelTestCase',
  18. 'TestCase',
  19. )
  20. class TestCase(_TestCase):
  21. user_permissions = ()
  22. def setUp(self):
  23. # Create the test user and assign permissions
  24. self.user = get_user_model().objects.create_user(username='testuser')
  25. self.add_permissions(*self.user_permissions)
  26. # Initialize the test client
  27. self.client = Client()
  28. self.client.force_login(self.user)
  29. #
  30. # Permissions management
  31. #
  32. def add_permissions(self, *names):
  33. """
  34. Assign a set of permissions to the test user. Accepts permission names in the form <app>.<action>_<model>.
  35. """
  36. for name in names:
  37. ct, action = resolve_permission_ct(name)
  38. obj_perm = ObjectPermission(name=name, actions=[action])
  39. obj_perm.save()
  40. obj_perm.users.add(self.user)
  41. obj_perm.object_types.add(ct)
  42. #
  43. # Custom assertions
  44. #
  45. def assertHttpStatus(self, response, expected_status):
  46. """
  47. TestCase method. Provide more detail in the event of an unexpected HTTP response.
  48. """
  49. err_message = None
  50. # Construct an error message only if we know the test is going to fail
  51. if response.status_code != expected_status:
  52. if hasattr(response, 'data'):
  53. # REST API response; pass the response data through directly
  54. err = response.data
  55. else:
  56. # Attempt to extract form validation errors from the response HTML
  57. form_errors = extract_form_failures(response.content)
  58. err = form_errors or response.content or 'No data'
  59. err_message = f"Expected HTTP status {expected_status}; received {response.status_code}: {err}"
  60. self.assertEqual(response.status_code, expected_status, err_message)
  61. class ModelTestCase(TestCase):
  62. """
  63. Parent class for TestCases which deal with models.
  64. """
  65. model = None
  66. def _get_queryset(self):
  67. """
  68. Return a base queryset suitable for use in test methods.
  69. """
  70. return self.model.objects.all()
  71. def prepare_instance(self, instance):
  72. """
  73. Test cases can override this method to perform any necessary manipulation of an instance prior to its evaluation
  74. against test data. For example, it can be used to decrypt a Secret's plaintext attribute.
  75. """
  76. return instance
  77. def model_to_dict(self, instance, fields, api=False):
  78. """
  79. Return a dictionary representation of an instance.
  80. """
  81. # Prepare the instance and call Django's model_to_dict() to extract all fields
  82. model_dict = model_to_dict(self.prepare_instance(instance), fields=fields)
  83. # Map any additional (non-field) instance attributes that were specified
  84. for attr in fields:
  85. if hasattr(instance, attr) and attr not in model_dict:
  86. model_dict[attr] = getattr(instance, attr)
  87. for key, value in list(model_dict.items()):
  88. try:
  89. field = instance._meta.get_field(key)
  90. except FieldDoesNotExist:
  91. # Attribute is not a model field
  92. continue
  93. # Handle ManyToManyFields
  94. if value and type(field) in (ManyToManyField, TaggableManager):
  95. if field.related_model in (ContentType, ObjectType) and api:
  96. model_dict[key] = sorted([content_type_identifier(ct) for ct in value])
  97. else:
  98. model_dict[key] = sorted([obj.pk for obj in value])
  99. elif api:
  100. # Replace ContentType numeric IDs with <app_label>.<model>
  101. if type(getattr(instance, key)) in (ContentType, ObjectType):
  102. ct = ObjectType.objects.get(pk=value)
  103. model_dict[key] = content_type_identifier(ct)
  104. # Convert IPNetwork instances to strings
  105. elif type(value) is IPNetwork:
  106. model_dict[key] = str(value)
  107. else:
  108. field = instance._meta.get_field(key)
  109. # Convert ArrayFields to CSV strings
  110. if type(field) is ArrayField:
  111. if type(field.base_field) is ArrayField:
  112. # Handle nested arrays (e.g. choice sets)
  113. model_dict[key] = '\n'.join([f'{k},{v}' for k, v in value])
  114. else:
  115. model_dict[key] = ','.join([str(v) for v in value])
  116. # JSON
  117. if type(field) is JSONField and value is not None:
  118. model_dict[key] = json.dumps(value)
  119. return model_dict
  120. #
  121. # Custom assertions
  122. #
  123. def assertInstanceEqual(self, instance, data, exclude=None, api=False):
  124. """
  125. Compare a model instance to a dictionary, checking that its attribute values match those specified
  126. in the dictionary.
  127. :param instance: Python object instance
  128. :param data: Dictionary of test data used to define the instance
  129. :param exclude: List of fields to exclude from comparison (e.g. passwords, which get hashed)
  130. :param api: Set to True is the data is a JSON representation of the instance
  131. """
  132. if exclude is None:
  133. exclude = []
  134. fields = [k for k in data.keys() if k not in exclude]
  135. model_dict = self.model_to_dict(instance, fields=fields, api=api)
  136. # Omit any dictionary keys which are not instance attributes or have been excluded
  137. relevant_data = {
  138. k: v for k, v in data.items() if hasattr(instance, k) and k not in exclude
  139. }
  140. self.assertDictEqual(model_dict, relevant_data)