filtersets.py 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166
  1. import django_filters
  2. from datetime import datetime, timezone
  3. from itertools import chain
  4. from mptt.models import MPTTModel
  5. from django.contrib.contenttypes.fields import GenericForeignKey, GenericRelation
  6. from django.contrib.contenttypes.models import ContentType
  7. from django.db.models import ForeignKey, ManyToManyField, ManyToManyRel, ManyToOneRel, OneToOneRel
  8. from django.utils.module_loading import import_string
  9. from taggit.managers import TaggableManager
  10. from extras.filters import TagFilter
  11. from utilities.filters import ContentTypeFilter, TreeNodeMultipleChoiceFilter
  12. from core.models import ObjectType
  13. __all__ = (
  14. 'BaseFilterSetTests',
  15. 'ChangeLoggedFilterSetTests',
  16. )
  17. EXEMPT_MODEL_FIELDS = (
  18. 'comments',
  19. 'custom_field_data',
  20. 'level', # MPTT
  21. 'lft', # MPTT
  22. 'rght', # MPTT
  23. 'tree_id', # MPTT
  24. )
  25. class BaseFilterSetTests:
  26. queryset = None
  27. filterset = None
  28. ignore_fields = tuple()
  29. def get_m2m_filter_name(self, field):
  30. """
  31. Given a ManyToManyField, determine the correct name for its corresponding Filter. Individual test
  32. cases may override this method to prescribe deviations for specific fields.
  33. """
  34. related_model_name = field.related_model._meta.verbose_name
  35. return related_model_name.lower().replace(' ', '_')
  36. def get_filters_for_model_field(self, field):
  37. """
  38. Given a model field, return an iterable of (name, class) for each filter that should be defined on
  39. the model's FilterSet class. If the appropriate filter class cannot be determined, it will be None.
  40. """
  41. # ForeignKey & OneToOneField
  42. if issubclass(field.__class__, ForeignKey) or type(field) is OneToOneRel:
  43. # Relationships to ContentType (used as part of a GFK) do not need a filter
  44. if field.related_model is ContentType:
  45. return [(None, None)]
  46. # ForeignKeys to ObjectType need two filters: 'app.model' & PK
  47. if field.related_model is ObjectType:
  48. return [
  49. (field.name, ContentTypeFilter),
  50. (f'{field.name}_id', django_filters.ModelMultipleChoiceFilter),
  51. ]
  52. # ForeignKey to an MPTT-enabled model
  53. if issubclass(field.related_model, MPTTModel) and field.model is not field.related_model:
  54. return [(f'{field.name}_id', TreeNodeMultipleChoiceFilter)]
  55. return [(f'{field.name}_id', django_filters.ModelMultipleChoiceFilter)]
  56. # Many-to-many relationships (forward & backward)
  57. elif type(field) in (ManyToManyField, ManyToManyRel):
  58. filter_name = self.get_m2m_filter_name(field)
  59. # ManyToManyFields to ObjectType need two filters: 'app.model' & PK
  60. if field.related_model is ObjectType:
  61. return [
  62. (filter_name, ContentTypeFilter),
  63. (f'{filter_name}_id', django_filters.ModelMultipleChoiceFilter),
  64. ]
  65. return [(f'{filter_name}_id', django_filters.ModelMultipleChoiceFilter)]
  66. # Tag manager
  67. if type(field) is TaggableManager:
  68. return [('tag', TagFilter)]
  69. # Unable to determine the correct filter class
  70. return [(field.name, None)]
  71. def test_id(self):
  72. """
  73. Test filtering for two PKs from a set of >2 objects.
  74. """
  75. params = {'id': self.queryset.values_list('pk', flat=True)[:2]}
  76. self.assertGreater(self.queryset.count(), 2)
  77. self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
  78. def test_missing_filters(self):
  79. """
  80. Check for any model fields which do not have the required filter(s) defined.
  81. """
  82. app_label = self.__class__.__module__.split('.')[0]
  83. model = self.queryset.model
  84. model_name = model.__name__
  85. # Import the FilterSet class & sanity check it
  86. filterset = import_string(f'{app_label}.filtersets.{model_name}FilterSet')
  87. self.assertEqual(model, filterset.Meta.model, "FilterSet model does not match!")
  88. filters = filterset.get_filters()
  89. # Check for missing filters
  90. for model_field in model._meta.get_fields():
  91. # Skip private fields
  92. if model_field.name.startswith('_'):
  93. continue
  94. # Skip ignored fields
  95. if model_field.name in chain(self.ignore_fields, EXEMPT_MODEL_FIELDS):
  96. continue
  97. # Skip reverse ForeignKey relationships
  98. if type(model_field) is ManyToOneRel:
  99. continue
  100. # Skip generic relationships
  101. if type(model_field) in (GenericForeignKey, GenericRelation):
  102. continue
  103. for filter_name, filter_class in self.get_filters_for_model_field(model_field):
  104. if filter_name is None:
  105. # Field is exempt
  106. continue
  107. # Check that the filter is defined
  108. self.assertIn(
  109. filter_name,
  110. filters.keys(),
  111. f'No filter defined for {filter_name} ({model_field.name})!'
  112. )
  113. # Check that the filter class is correct
  114. filter = filters[filter_name]
  115. if filter_class is not None:
  116. self.assertIs(
  117. type(filter),
  118. filter_class,
  119. f"Invalid filter class {type(filter)} for {filter_name} (should be {filter_class})!"
  120. )
  121. class ChangeLoggedFilterSetTests(BaseFilterSetTests):
  122. def test_created(self):
  123. pk_list = self.queryset.values_list('pk', flat=True)[:2]
  124. self.queryset.filter(pk__in=pk_list).update(created=datetime(2021, 1, 1, 0, 0, 0, tzinfo=timezone.utc))
  125. params = {'created': ['2021-01-01T00:00:00']}
  126. self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
  127. def test_last_updated(self):
  128. pk_list = self.queryset.values_list('pk', flat=True)[:2]
  129. self.queryset.filter(pk__in=pk_list).update(last_updated=datetime(2021, 1, 2, 0, 0, 0, tzinfo=timezone.utc))
  130. params = {'last_updated': ['2021-01-02T00:00:00']}
  131. self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)