|
|
@@ -1,11 +1,14 @@
|
|
|
+import django_filters
|
|
|
from datetime import datetime, timezone
|
|
|
from itertools import chain
|
|
|
+from mptt.models import MPTTModel
|
|
|
|
|
|
from django.contrib.contenttypes.fields import GenericForeignKey, GenericRelation
|
|
|
from django.contrib.contenttypes.models import ContentType
|
|
|
from django.db.models import ForeignKey, ManyToManyField, ManyToManyRel, ManyToOneRel, OneToOneRel
|
|
|
from django.utils.module_loading import import_string
|
|
|
from taggit.managers import TaggableManager
|
|
|
+from utilities.filters import TreeNodeMultipleChoiceFilter
|
|
|
|
|
|
from core.models import ObjectType
|
|
|
|
|
|
@@ -52,6 +55,21 @@ class BaseFilterSetTests:
|
|
|
related_model_name = field.related_model._meta.verbose_name
|
|
|
return related_model_name.lower().replace(' ', '_')
|
|
|
|
|
|
+ @staticmethod
|
|
|
+ def get_filter_class_for_field(field):
|
|
|
+
|
|
|
+ # ForeignKey & OneToOneField
|
|
|
+ if issubclass(field.__class__, ForeignKey) or type(field) is OneToOneRel:
|
|
|
+
|
|
|
+ # ForeignKey to an MPTT-enabled model
|
|
|
+ if issubclass(field.related_model, MPTTModel) and field.model is not field.related_model:
|
|
|
+ return TreeNodeMultipleChoiceFilter
|
|
|
+
|
|
|
+ return django_filters.ModelMultipleChoiceFilter
|
|
|
+
|
|
|
+ # Unable to determine the correct filter class
|
|
|
+ return None
|
|
|
+
|
|
|
def test_id(self):
|
|
|
"""
|
|
|
Test filtering for two PKs from a set of >2 objects.
|
|
|
@@ -76,7 +94,7 @@ class BaseFilterSetTests:
|
|
|
filterset = import_string(f'{app_label}.filtersets.{model_name}FilterSet')
|
|
|
self.assertEqual(model, filterset.Meta.model, "FilterSet model does not match!")
|
|
|
|
|
|
- filterset_fields = sorted(filterset.get_filters())
|
|
|
+ filters = filterset.get_filters()
|
|
|
|
|
|
# Check for missing filters
|
|
|
for model_field in model._meta.get_fields():
|
|
|
@@ -95,26 +113,36 @@ class BaseFilterSetTests:
|
|
|
|
|
|
# One-to-one & one-to-many relationships
|
|
|
if issubclass(model_field.__class__, ForeignKey) or type(model_field) is OneToOneRel:
|
|
|
+
|
|
|
+ # Relationships to ContentType (used as part of a GFK) do not need a filter
|
|
|
if model_field.related_model is ContentType:
|
|
|
- # Relationships to ContentType (used as part of a GFK) do not need a filter
|
|
|
continue
|
|
|
- elif model_field.related_model is ObjectType:
|
|
|
- # Filters to ObjectType use 'app.model' rather than numeric PK, so we omit the _id suffix
|
|
|
+
|
|
|
+ # Filters to ObjectType use 'app.model' rather than numeric PK, so we omit the _id suffix
|
|
|
+ if model_field.related_model is ObjectType:
|
|
|
filter_name = model_field.name
|
|
|
else:
|
|
|
filter_name = f'{model_field.name}_id'
|
|
|
+
|
|
|
self.assertIn(
|
|
|
filter_name,
|
|
|
- filterset_fields,
|
|
|
+ filters,
|
|
|
f'No filter defined for {filter_name} ({model_field.name})!'
|
|
|
)
|
|
|
-
|
|
|
+ if filter_class := self.get_filter_class_for_field(model_field):
|
|
|
+ self.assertIs(
|
|
|
+ type(filters[filter_name]),
|
|
|
+ filter_class,
|
|
|
+ f"Invalid filter class for {filter_name}!"
|
|
|
+ )
|
|
|
+
|
|
|
+ # Many-to-many relationships (forward & backward)
|
|
|
elif type(model_field) in (ManyToManyField, ManyToManyRel):
|
|
|
filter_name = self.get_m2m_filter_name(model_field)
|
|
|
filter_name = f'{filter_name}_id'
|
|
|
self.assertIn(
|
|
|
filter_name,
|
|
|
- filterset_fields,
|
|
|
+ filters,
|
|
|
f'No filter defined for {filter_name} ({model_field.name})!'
|
|
|
)
|
|
|
|
|
|
@@ -124,13 +152,13 @@ class BaseFilterSetTests:
|
|
|
|
|
|
# Tags
|
|
|
elif type(model_field) is TaggableManager:
|
|
|
- self.assertIn('tag', filterset_fields, f'No filter defined for {model_field.name}!')
|
|
|
+ self.assertIn('tag', filters, f'No filter defined for {model_field.name}!')
|
|
|
|
|
|
# All other fields
|
|
|
else:
|
|
|
self.assertIn(
|
|
|
model_field.name,
|
|
|
- filterset_fields,
|
|
|
+ filters,
|
|
|
f'No defined found for {model_field.name} ({type(model_field)})!'
|
|
|
)
|
|
|
|