base.py 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161
  1. from django.contrib.auth.models import User
  2. from django.contrib.contenttypes.models import ContentType
  3. from django.contrib.postgres.fields import ArrayField
  4. from django.core.exceptions import FieldDoesNotExist
  5. from django.db.models import ManyToManyField
  6. from django.forms.models import model_to_dict
  7. from django.test import Client, TestCase as _TestCase
  8. from netaddr import IPNetwork
  9. from taggit.managers import TaggableManager
  10. from users.models import ObjectPermission
  11. from utilities.permissions import resolve_permission_ct
  12. from .utils import extract_form_failures
  13. __all__ = (
  14. 'ModelTestCase',
  15. 'TestCase',
  16. )
  17. class TestCase(_TestCase):
  18. user_permissions = ()
  19. def setUp(self):
  20. # Create the test user and assign permissions
  21. self.user = User.objects.create_user(username='testuser')
  22. self.add_permissions(*self.user_permissions)
  23. # Initialize the test client
  24. self.client = Client()
  25. self.client.force_login(self.user)
  26. #
  27. # Permissions management
  28. #
  29. def add_permissions(self, *names):
  30. """
  31. Assign a set of permissions to the test user. Accepts permission names in the form <app>.<action>_<model>.
  32. """
  33. for name in names:
  34. ct, action = resolve_permission_ct(name)
  35. obj_perm = ObjectPermission(name=name, actions=[action])
  36. obj_perm.save()
  37. obj_perm.users.add(self.user)
  38. obj_perm.object_types.add(ct)
  39. #
  40. # Custom assertions
  41. #
  42. def assertHttpStatus(self, response, expected_status):
  43. """
  44. TestCase method. Provide more detail in the event of an unexpected HTTP response.
  45. """
  46. err_message = None
  47. # Construct an error message only if we know the test is going to fail
  48. if response.status_code != expected_status:
  49. if hasattr(response, 'data'):
  50. # REST API response; pass the response data through directly
  51. err = response.data
  52. else:
  53. # Attempt to extract form validation errors from the response HTML
  54. form_errors = extract_form_failures(response.content)
  55. err = form_errors or response.content or 'No data'
  56. err_message = f"Expected HTTP status {expected_status}; received {response.status_code}: {err}"
  57. self.assertEqual(response.status_code, expected_status, err_message)
  58. class ModelTestCase(TestCase):
  59. """
  60. Parent class for TestCases which deal with models.
  61. """
  62. model = None
  63. def _get_queryset(self):
  64. """
  65. Return a base queryset suitable for use in test methods.
  66. """
  67. return self.model.objects.all()
  68. def prepare_instance(self, instance):
  69. """
  70. Test cases can override this method to perform any necessary manipulation of an instance prior to its evaluation
  71. against test data. For example, it can be used to decrypt a Secret's plaintext attribute.
  72. """
  73. return instance
  74. def model_to_dict(self, instance, fields, api=False):
  75. """
  76. Return a dictionary representation of an instance.
  77. """
  78. # Prepare the instance and call Django's model_to_dict() to extract all fields
  79. model_dict = model_to_dict(self.prepare_instance(instance), fields=fields)
  80. # Map any additional (non-field) instance attributes that were specified
  81. for attr in fields:
  82. if hasattr(instance, attr) and attr not in model_dict:
  83. model_dict[attr] = getattr(instance, attr)
  84. for key, value in list(model_dict.items()):
  85. try:
  86. field = instance._meta.get_field(key)
  87. except FieldDoesNotExist:
  88. # Attribute is not a model field
  89. continue
  90. # Handle ManyToManyFields
  91. if value and type(field) in (ManyToManyField, TaggableManager):
  92. if field.related_model is ContentType:
  93. model_dict[key] = sorted([f'{ct.app_label}.{ct.model}' for ct in value])
  94. else:
  95. model_dict[key] = sorted([obj.pk for obj in value])
  96. if api:
  97. # Replace ContentType numeric IDs with <app_label>.<model>
  98. if type(getattr(instance, key)) is ContentType:
  99. ct = ContentType.objects.get(pk=value)
  100. model_dict[key] = f'{ct.app_label}.{ct.model}'
  101. # Convert IPNetwork instances to strings
  102. elif type(value) is IPNetwork:
  103. model_dict[key] = str(value)
  104. else:
  105. # Convert ArrayFields to CSV strings
  106. if type(instance._meta.get_field(key)) is ArrayField:
  107. model_dict[key] = ','.join([str(v) for v in value])
  108. return model_dict
  109. #
  110. # Custom assertions
  111. #
  112. def assertInstanceEqual(self, instance, data, exclude=None, api=False):
  113. """
  114. Compare a model instance to a dictionary, checking that its attribute values match those specified
  115. in the dictionary.
  116. :param instance: Python object instance
  117. :param data: Dictionary of test data used to define the instance
  118. :param exclude: List of fields to exclude from comparison (e.g. passwords, which get hashed)
  119. :param api: Set to True is the data is a JSON representation of the instance
  120. """
  121. if exclude is None:
  122. exclude = []
  123. fields = [k for k in data.keys() if k not in exclude]
  124. model_dict = self.model_to_dict(instance, fields=fields, api=api)
  125. # Omit any dictionary keys which are not instance attributes or have been excluded
  126. relevant_data = {
  127. k: v for k, v in data.items() if hasattr(instance, k) and k not in exclude
  128. }
  129. self.assertDictEqual(model_dict, relevant_data)