filtersets.py 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167
  1. from datetime import datetime, timezone
  2. from itertools import chain
  3. import django_filters
  4. from django.contrib.contenttypes.fields import GenericForeignKey, GenericRelation
  5. from django.contrib.contenttypes.models import ContentType
  6. from django.db.models import ForeignKey, ManyToManyField, ManyToManyRel, ManyToOneRel, OneToOneRel
  7. from django.utils.module_loading import import_string
  8. from mptt.models import MPTTModel
  9. from taggit.managers import TaggableManager
  10. from extras.filters import TagFilter
  11. from utilities.filters import MultiValueContentTypeFilter, TreeNodeMultipleChoiceFilter
  12. __all__ = (
  13. 'BaseFilterSetTests',
  14. 'ChangeLoggedFilterSetTests',
  15. )
  16. EXEMPT_MODEL_FIELDS = (
  17. 'comments',
  18. 'custom_field_data',
  19. 'level', # MPTT
  20. 'lft', # MPTT
  21. 'rght', # MPTT
  22. 'tree_id', # MPTT
  23. )
  24. class BaseFilterSetTests:
  25. queryset = None
  26. filterset = None
  27. ignore_fields = tuple()
  28. filter_name_map = {}
  29. def get_m2m_filter_name(self, field):
  30. """
  31. Given a ManyToManyField, determine the correct name for its corresponding Filter. Individual test
  32. cases may override this method to prescribe deviations for specific fields.
  33. """
  34. related_model_name = field.related_model._meta.verbose_name
  35. return related_model_name.lower().replace(' ', '_')
  36. def get_filters_for_model_field(self, field):
  37. """
  38. Given a model field, return an iterable of (name, class) for each filter that should be defined on
  39. the model's FilterSet class. If the appropriate filter class cannot be determined, it will be None.
  40. filter_name_map provides a mechanism for developers to provide an actual field name for the
  41. filter that is being resolved, given the field's actual name.
  42. """
  43. # If an alias is not present in filter_name_map, then use field.name
  44. filter_name = self.filter_name_map.get(field.name, field.name)
  45. # ForeignKey & OneToOneField
  46. if issubclass(field.__class__, ForeignKey) or type(field) is OneToOneRel:
  47. # Relationships to ContentType (used as part of a GFK) do not need a filter
  48. if field.related_model is ContentType:
  49. return [(None, None)]
  50. # ForeignKey to an MPTT-enabled model
  51. if issubclass(field.related_model, MPTTModel) and field.model is not field.related_model:
  52. return [(f'{filter_name}_id', TreeNodeMultipleChoiceFilter)]
  53. return [(f'{filter_name}_id', django_filters.ModelMultipleChoiceFilter)]
  54. # Many-to-many relationships (forward & backward)
  55. if type(field) in (ManyToManyField, ManyToManyRel):
  56. filter_name = self.get_m2m_filter_name(field)
  57. filter_name = self.filter_name_map.get(filter_name, filter_name)
  58. # ManyToManyFields to ContentType need two filters: 'app.model' & PK
  59. if field.related_model is ContentType:
  60. # Standardize on object_type for filter name even though it's technically a ContentType
  61. filter_name = 'object_type'
  62. return [
  63. (filter_name, MultiValueContentTypeFilter),
  64. (f'{filter_name}_id', django_filters.ModelMultipleChoiceFilter),
  65. ]
  66. return [(f'{filter_name}_id', django_filters.ModelMultipleChoiceFilter)]
  67. # Tag manager
  68. if type(field) is TaggableManager:
  69. return [('tag', TagFilter)]
  70. # Unable to determine the correct filter class
  71. return [(filter_name, None)]
  72. def test_id(self):
  73. """
  74. Test filtering for two PKs from a set of >2 objects.
  75. """
  76. params = {'id': self.queryset.values_list('pk', flat=True)[:2]}
  77. self.assertGreater(self.queryset.count(), 2)
  78. self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
  79. def test_missing_filters(self):
  80. """
  81. Check for any model fields which do not have the required filter(s) defined.
  82. """
  83. app_label = self.__class__.__module__.split('.')[0]
  84. model = self.queryset.model
  85. model_name = model.__name__
  86. # Import the FilterSet class & sanity check it
  87. filterset = import_string(f'{app_label}.filtersets.{model_name}FilterSet')
  88. self.assertEqual(model, filterset.Meta.model, "FilterSet model does not match!")
  89. filters = filterset.get_filters()
  90. # Check for missing filters
  91. for model_field in model._meta.get_fields():
  92. # Skip private fields
  93. if model_field.name.startswith('_'):
  94. continue
  95. # Skip ignored fields
  96. if model_field.name in chain(self.ignore_fields, EXEMPT_MODEL_FIELDS):
  97. continue
  98. # Skip reverse ForeignKey relationships
  99. if type(model_field) is ManyToOneRel:
  100. continue
  101. # Skip generic relationships
  102. if type(model_field) in (GenericForeignKey, GenericRelation):
  103. continue
  104. for filter_name, filter_class in self.get_filters_for_model_field(model_field):
  105. if filter_name is None:
  106. # Field is exempt
  107. continue
  108. # Check that the filter is defined
  109. self.assertIn(
  110. filter_name,
  111. filters.keys(),
  112. f'No filter defined for {filter_name} ({model_field.name})!'
  113. )
  114. # Check that the filter class is correct
  115. filter = filters[filter_name]
  116. if filter_class is not None:
  117. self.assertIsInstance(
  118. filter,
  119. filter_class,
  120. f"Invalid filter class {type(filter)} for {filter_name} (should be {filter_class})!"
  121. )
  122. class ChangeLoggedFilterSetTests(BaseFilterSetTests):
  123. def test_created(self):
  124. pk_list = self.queryset.values_list('pk', flat=True)[:2]
  125. self.queryset.filter(pk__in=pk_list).update(created=datetime(2021, 1, 1, 0, 0, 0, tzinfo=timezone.utc))
  126. params = {'created': ['2021-01-01T00:00:00']}
  127. self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
  128. def test_last_updated(self):
  129. pk_list = self.queryset.values_list('pk', flat=True)[:2]
  130. self.queryset.filter(pk__in=pk_list).update(last_updated=datetime(2021, 1, 2, 0, 0, 0, tzinfo=timezone.utc))
  131. params = {'last_updated': ['2021-01-02T00:00:00']}
  132. self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)