base.py 9.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237
  1. import json
  2. from contextlib import contextmanager
  3. from django.contrib.contenttypes.fields import GenericForeignKey
  4. from django.contrib.contenttypes.models import ContentType
  5. from django.contrib.postgres.fields import ArrayField, RangeField
  6. from django.core.exceptions import FieldDoesNotExist
  7. from django.db import transaction
  8. from django.db.models import JSONField, ManyToManyField, ManyToManyRel
  9. from django.forms.models import model_to_dict
  10. from django.test import Client
  11. from django.test import TestCase as _TestCase
  12. from netaddr import IPNetwork
  13. from taggit.managers import TaggableManager
  14. from core.models import ObjectType
  15. from users.models import ObjectPermission, User
  16. from utilities.data import ranges_to_string
  17. from utilities.object_types import object_type_identifier
  18. from utilities.permissions import resolve_permission_type
  19. from .utils import DUMMY_CF_DATA, extract_form_failures
  20. __all__ = (
  21. 'ModelTestCase',
  22. 'TestCase',
  23. )
  24. class TestCase(_TestCase):
  25. user_permissions = ()
  26. def setUp(self):
  27. # Create the test user and assign permissions
  28. self.user = User.objects.create_user(username='testuser')
  29. self.add_permissions(*self.user_permissions)
  30. # Initialize the test client
  31. self.client = Client()
  32. self.client.force_login(self.user)
  33. @contextmanager
  34. def cleanupSubTest(self, **params):
  35. """
  36. Context manager that wraps subTest with automatic cleanup.
  37. All database changes within the context will be rolled back.
  38. """
  39. sid = transaction.savepoint()
  40. try:
  41. with self.subTest(**params):
  42. yield
  43. finally:
  44. transaction.savepoint_rollback(sid)
  45. #
  46. # Permissions management
  47. #
  48. def add_permissions(self, *names):
  49. """
  50. Assign a set of permissions to the test user. Accepts permission names in the form <app>.<action>_<model>.
  51. """
  52. for name in names:
  53. object_type, action = resolve_permission_type(name)
  54. obj_perm = ObjectPermission(name=name, actions=[action])
  55. obj_perm.save()
  56. obj_perm.users.add(self.user)
  57. obj_perm.object_types.add(object_type)
  58. def remove_permissions(self, *names):
  59. """
  60. Remove a set of permissions from the test user. Accepts permission names in the form <app>.<action>_<model>.
  61. """
  62. for name in names:
  63. object_type, action = resolve_permission_type(name)
  64. ObjectPermission.objects.filter(
  65. actions__contains=[action], object_types=object_type, users=self.user
  66. ).delete()
  67. #
  68. # Custom assertions
  69. #
  70. def assertObjectChangeData(self, objectchange, prechange_data, postchange_data):
  71. """
  72. Assert that an ObjectChange record has the expected prechange_data and postchange_data.
  73. Pass None to assert the field is null; pass any non-None value to assert it is populated.
  74. """
  75. if prechange_data is None:
  76. self.assertIsNone(objectchange.prechange_data, "Expected prechange_data to be None")
  77. else:
  78. self.assertIsNotNone(objectchange.prechange_data, "Expected prechange_data to be populated")
  79. if postchange_data is None:
  80. self.assertIsNone(objectchange.postchange_data, "Expected postchange_data to be None")
  81. else:
  82. self.assertIsNotNone(objectchange.postchange_data, "Expected postchange_data to be populated")
  83. def assertHttpStatus(self, response, expected_status):
  84. """
  85. TestCase method. Provide more detail in the event of an unexpected HTTP response.
  86. """
  87. err_message = None
  88. # Construct an error message only if we know the test is going to fail
  89. if response.status_code != expected_status:
  90. if hasattr(response, 'data'):
  91. # REST API response; pass the response data through directly
  92. err = response.data
  93. else:
  94. # Attempt to extract form validation errors from the response HTML
  95. form_errors = extract_form_failures(response.content)
  96. err = form_errors or response.content or 'No data'
  97. err_message = f"Expected HTTP status {expected_status}; received {response.status_code}: {err}"
  98. self.assertEqual(response.status_code, expected_status, err_message)
  99. class ModelTestCase(TestCase):
  100. """
  101. Parent class for TestCases which deal with models.
  102. """
  103. model = None
  104. def _get_queryset(self):
  105. """
  106. Return a base queryset suitable for use in test methods.
  107. """
  108. return self.model.objects.all()
  109. def prepare_instance(self, instance):
  110. """
  111. Test cases can override this method to perform any necessary manipulation of an instance prior to its evaluation
  112. against test data. For example, it can be used to decrypt a Secret's plaintext attribute.
  113. """
  114. return instance
  115. def model_to_dict(self, instance, fields, api=False):
  116. """
  117. Return a dictionary representation of an instance.
  118. """
  119. # Prepare the instance and call Django's model_to_dict() to extract all fields
  120. model_dict = model_to_dict(self.prepare_instance(instance), fields=fields)
  121. # Map any additional (non-field) instance attributes that were specified
  122. for attr in fields:
  123. if hasattr(instance, attr) and attr not in model_dict:
  124. model_dict[attr] = getattr(instance, attr)
  125. for key, value in list(model_dict.items()):
  126. try:
  127. field = instance._meta.get_field(key)
  128. except FieldDoesNotExist:
  129. # Attribute is not a model field
  130. continue
  131. # Handle ManyToManyFields
  132. if value and type(field) in (ManyToManyField, ManyToManyRel, TaggableManager):
  133. # Resolve reverse M2M relationships
  134. if isinstance(field, ManyToManyRel):
  135. value = getattr(instance, field.related_name).all()
  136. if field.related_model in (ContentType, ObjectType) and api:
  137. model_dict[key] = sorted([object_type_identifier(ot) for ot in value])
  138. else:
  139. model_dict[key] = sorted([obj.pk for obj in value])
  140. # Handle GenericForeignKeys
  141. elif value and type(field) is GenericForeignKey:
  142. model_dict[key] = value.pk
  143. # Handle API output
  144. elif api:
  145. # Replace ContentType numeric IDs with <app_label>.<model>
  146. if type(getattr(instance, key)) in (ContentType, ObjectType):
  147. object_type = ObjectType.objects.get(pk=value)
  148. model_dict[key] = object_type_identifier(object_type)
  149. # Convert IPNetwork instances to strings
  150. elif type(value) is IPNetwork:
  151. model_dict[key] = str(value)
  152. # Normalize arrays of numeric ranges (e.g. VLAN IDs or port ranges).
  153. # DB uses canonical half-open [lo, hi) via NumericRange; API uses inclusive [lo, hi].
  154. # Convert to inclusive pairs for stable API comparisons.
  155. elif type(field) is ArrayField and issubclass(type(field.base_field), RangeField):
  156. model_dict[key] = [[r.lower, r.upper - 1] for r in value]
  157. else:
  158. # Convert ArrayFields to CSV strings
  159. if type(field) is ArrayField:
  160. if getattr(field.base_field, 'choices', None):
  161. # Values for fields with pre-defined choices can be returned as lists
  162. model_dict[key] = value
  163. elif type(field.base_field) is ArrayField:
  164. # Handle nested arrays (e.g. choice sets)
  165. model_dict[key] = '\n'.join([f'{k},{v}' for k, v in value])
  166. elif issubclass(type(field.base_field), RangeField):
  167. # Handle arrays of numeric ranges (e.g. VLANGroup VLAN ID ranges)
  168. model_dict[key] = ranges_to_string(value)
  169. else:
  170. model_dict[key] = ','.join([str(v) for v in value])
  171. # JSON
  172. if type(field) is JSONField and value is not None:
  173. model_dict[key] = json.dumps(value)
  174. return model_dict
  175. #
  176. # Custom assertions
  177. #
  178. def assertInstanceEqual(self, instance, data, exclude=None, api=False):
  179. """
  180. Compare a model instance to a dictionary, checking that its attribute values match those specified
  181. in the dictionary.
  182. :param instance: Python object instance
  183. :param data: Dictionary of test data used to define the instance
  184. :param exclude: List of fields to exclude from comparison (e.g. passwords, which get hashed)
  185. :param api: Set to True is the data is a JSON representation of the instance
  186. """
  187. if exclude is None:
  188. exclude = []
  189. fields = [k for k in data.keys() if k not in exclude]
  190. model_dict = self.model_to_dict(instance, fields=fields, api=api)
  191. # Omit any dictionary keys which are not instance attributes or have been excluded
  192. model_data = {
  193. k: v for k, v in data.items() if hasattr(instance, k) and k not in exclude
  194. }
  195. self.assertDictEqual(model_dict, model_data)
  196. # Validate any custom field data, if present
  197. if getattr(instance, 'custom_field_data', None):
  198. self.assertDictEqual(instance.custom_field_data, DUMMY_CF_DATA)