Просмотр исходного кода

Extend logic for validating filter class

Jeremy Stretch 1 год назад
Родитель
Сommit
313e63622b

+ 18 - 12
netbox/extras/filtersets.py

@@ -91,8 +91,9 @@ class EventRuleFilterSet(NetBoxModelFilterSet):
         method='search',
         label=_('Search'),
     )
-    object_type_id = MultiValueNumberFilter(
-        field_name='object_types__id'
+    object_type_id = django_filters.ModelMultipleChoiceFilter(
+        queryset=ObjectType.objects.all(),
+        field_name='object_types'
     )
     object_type = ContentTypeFilter(
         field_name='object_types'
@@ -128,14 +129,16 @@ class CustomFieldFilterSet(ChangeLoggedModelFilterSet):
     type = django_filters.MultipleChoiceFilter(
         choices=CustomFieldTypeChoices
     )
-    object_type_id = MultiValueNumberFilter(
-        field_name='object_types__id'
+    object_type_id = django_filters.ModelMultipleChoiceFilter(
+        queryset=ObjectType.objects.all(),
+        field_name='object_types'
     )
     object_type = ContentTypeFilter(
         field_name='object_types'
     )
-    related_object_type_id = MultiValueNumberFilter(
-        field_name='related_object_type__id'
+    related_object_type_id = django_filters.ModelMultipleChoiceFilter(
+        queryset=ObjectType.objects.all(),
+        field_name='related_object_type'
     )
     related_object_type = ContentTypeFilter()
     choice_set_id = django_filters.ModelMultipleChoiceFilter(
@@ -199,8 +202,9 @@ class CustomLinkFilterSet(ChangeLoggedModelFilterSet):
         method='search',
         label=_('Search'),
     )
-    object_type_id = MultiValueNumberFilter(
-        field_name='object_types__id'
+    object_type_id = django_filters.ModelMultipleChoiceFilter(
+        queryset=ObjectType.objects.all(),
+        field_name='object_types'
     )
     object_type = ContentTypeFilter(
         field_name='object_types'
@@ -228,8 +232,9 @@ class ExportTemplateFilterSet(ChangeLoggedModelFilterSet):
         method='search',
         label=_('Search'),
     )
-    object_type_id = MultiValueNumberFilter(
-        field_name='object_types__id'
+    object_type_id = django_filters.ModelMultipleChoiceFilter(
+        queryset=ObjectType.objects.all(),
+        field_name='object_types'
     )
     object_type = ContentTypeFilter(
         field_name='object_types'
@@ -264,8 +269,9 @@ class SavedFilterFilterSet(ChangeLoggedModelFilterSet):
         method='search',
         label=_('Search'),
     )
-    object_type_id = MultiValueNumberFilter(
-        field_name='object_types__id'
+    object_type_id = django_filters.ModelMultipleChoiceFilter(
+        queryset=ObjectType.objects.all(),
+        field_name='object_types'
     )
     object_type = ContentTypeFilter(
         field_name='object_types'

+ 2 - 4
netbox/ipam/tests/test_filtersets.py

@@ -198,8 +198,7 @@ class VRFTestCase(TestCase, ChangeLoggedFilterSetTests):
     queryset = VRF.objects.all()
     filterset = VRFFilterSet
 
-    @staticmethod
-    def get_m2m_filter_name(field):
+    def get_m2m_filter_name(self, field):
         # Override filter names for import & export RouteTargets
         if field.name == 'import_targets':
             return 'import_target'
@@ -303,8 +302,7 @@ class RouteTargetTestCase(TestCase, ChangeLoggedFilterSetTests):
     queryset = RouteTarget.objects.all()
     filterset = RouteTargetFilterSet
 
-    @staticmethod
-    def get_m2m_filter_name(field):
+    def get_m2m_filter_name(self, field):
         # Override filter names for import & export VRFs and L2VPNs
         if field.name == 'importing_vrfs':
             return 'importing_vrf'

+ 5 - 3
netbox/users/filtersets.py

@@ -3,9 +3,10 @@ from django.contrib.auth import get_user_model
 from django.db.models import Q
 from django.utils.translation import gettext as _
 
+from core.models import ObjectType
 from netbox.filtersets import BaseFilterSet
 from users.models import Group, ObjectPermission, Token
-from utilities.filters import ContentTypeFilter, MultiValueNumberFilter
+from utilities.filters import ContentTypeFilter
 
 __all__ = (
     'GroupFilterSet',
@@ -134,8 +135,9 @@ class ObjectPermissionFilterSet(BaseFilterSet):
         method='search',
         label=_('Search'),
     )
-    object_type_id = MultiValueNumberFilter(
-        field_name='object_types__id'
+    object_type_id = django_filters.ModelMultipleChoiceFilter(
+        queryset=ObjectType.objects.all(),
+        field_name='object_types'
     )
     object_type = ContentTypeFilter(
         field_name='object_types'

+ 55 - 49
netbox/utilities/testing/filtersets.py

@@ -8,7 +8,9 @@ from django.contrib.contenttypes.models import ContentType
 from django.db.models import ForeignKey, ManyToManyField, ManyToManyRel, ManyToOneRel, OneToOneRel
 from django.utils.module_loading import import_string
 from taggit.managers import TaggableManager
-from utilities.filters import TreeNodeMultipleChoiceFilter
+
+from extras.filters import TagFilter
+from utilities.filters import ContentTypeFilter, TreeNodeMultipleChoiceFilter
 
 from core.models import ObjectType
 
@@ -46,8 +48,7 @@ class BaseFilterSetTests:
     filterset = None
     ignore_fields = tuple()
 
-    @staticmethod
-    def get_m2m_filter_name(field):
+    def get_m2m_filter_name(self, field):
         """
         Given a ManyToManyField, determine the correct name for its corresponding Filter. Individual test
         cases may override this method to prescribe deviations for specific fields.
@@ -55,20 +56,50 @@ class BaseFilterSetTests:
         related_model_name = field.related_model._meta.verbose_name
         return related_model_name.lower().replace(' ', '_')
 
-    @staticmethod
-    def get_filter_class_for_field(field):
-
+    def get_filters_for_model_field(self, field):
+        """
+        Given a model field, return an iterable of (name, class) for each filter that should be defined on
+        the model's FilterSet class. If the appropriate filter class cannot be determined, it will be None.
+        """
         # ForeignKey & OneToOneField
         if issubclass(field.__class__, ForeignKey) or type(field) is OneToOneRel:
 
+            # Relationships to ContentType (used as part of a GFK) do not need a filter
+            if field.related_model is ContentType:
+                return [(None, None)]
+
+            # ForeignKeys to ObjectType need two filters: 'app.model' & PK
+            if field.related_model is ObjectType:
+                return [
+                    (field.name, ContentTypeFilter),
+                    (f'{field.name}_id', django_filters.ModelMultipleChoiceFilter),
+                ]
+
             # ForeignKey to an MPTT-enabled model
             if issubclass(field.related_model, MPTTModel) and field.model is not field.related_model:
-                return TreeNodeMultipleChoiceFilter
+                return [(f'{field.name}_id', TreeNodeMultipleChoiceFilter)]
+
+            return [(f'{field.name}_id', django_filters.ModelMultipleChoiceFilter)]
 
-            return django_filters.ModelMultipleChoiceFilter
+        # Many-to-many relationships (forward & backward)
+        elif type(field) in (ManyToManyField, ManyToManyRel):
+            filter_name = self.get_m2m_filter_name(field)
+
+            # ManyToManyFields to ObjectType need two filters: 'app.model' & PK
+            if field.related_model is ObjectType:
+                return [
+                    (filter_name, ContentTypeFilter),
+                    (f'{filter_name}_id', django_filters.ModelMultipleChoiceFilter),
+                ]
+
+            return [(f'{filter_name}_id', django_filters.ModelMultipleChoiceFilter)]
+
+        # Tag manager
+        if type(field) is TaggableManager:
+            return [('tag', TagFilter)]
 
         # Unable to determine the correct filter class
-        return None
+        return [(field.name, None)]
 
     def test_id(self):
         """
@@ -111,57 +142,32 @@ class BaseFilterSetTests:
             if type(model_field) is ManyToOneRel:
                 continue
 
-            # One-to-one & one-to-many relationships
-            if issubclass(model_field.__class__, ForeignKey) or type(model_field) is OneToOneRel:
+            # TODO: Generic relationships
+            if type(model_field) in (GenericForeignKey, GenericRelation):
+                continue
 
-                # Relationships to ContentType (used as part of a GFK) do not need a filter
-                if model_field.related_model is ContentType:
-                    continue
+            for filter_name, filter_class in self.get_filters_for_model_field(model_field):
 
-                # Filters to ObjectType use 'app.model' rather than numeric PK, so we omit the _id suffix
-                if model_field.related_model is ObjectType:
-                    filter_name = model_field.name
-                else:
-                    filter_name = f'{model_field.name}_id'
+                if filter_name is None:
+                    # Field is exempt
+                    continue
 
+                # Check that the filter is defined
                 self.assertIn(
                     filter_name,
-                    filters,
+                    filters.keys(),
                     f'No filter defined for {filter_name} ({model_field.name})!'
                 )
-                if filter_class := self.get_filter_class_for_field(model_field):
+
+                # Check that the filter class is correct
+                filter = filters[filter_name]
+                if filter_class is not None:
                     self.assertIs(
-                        type(filters[filter_name]),
+                        type(filter),
                         filter_class,
-                        f"Invalid filter class for {filter_name}!"
+                        f"Invalid filter class {type(filter)} for {filter_name} (should be {filter_class})!"
                     )
 
-            # Many-to-many relationships (forward & backward)
-            elif type(model_field) in (ManyToManyField, ManyToManyRel):
-                filter_name = self.get_m2m_filter_name(model_field)
-                filter_name = f'{filter_name}_id'
-                self.assertIn(
-                    filter_name,
-                    filters,
-                    f'No filter defined for {filter_name} ({model_field.name})!'
-                )
-
-            # TODO: Generic relationships
-            elif type(model_field) in (GenericForeignKey, GenericRelation):
-                continue
-
-            # Tags
-            elif type(model_field) is TaggableManager:
-                self.assertIn('tag', filters, f'No filter defined for {model_field.name}!')
-
-            # All other fields
-            else:
-                self.assertIn(
-                    model_field.name,
-                    filters,
-                    f'No defined found for {model_field.name} ({type(model_field)})!'
-                )
-
 
 class ChangeLoggedFilterSetTests(BaseFilterSetTests):
 

+ 14 - 8
netbox/vpn/filtersets.py

@@ -169,11 +169,14 @@ class IKEPolicyFilterSet(NetBoxModelFilterSet):
     mode = django_filters.MultipleChoiceFilter(
         choices=IKEModeChoices
     )
-    ike_proposal_id = MultiValueNumberFilter(
-        field_name='proposals__id'
+    ike_proposal_id = django_filters.ModelMultipleChoiceFilter(
+        field_name='proposals',
+        queryset=IKEProposal.objects.all()
     )
-    ike_proposal = MultiValueCharFilter(
-        field_name='proposals__name'
+    ike_proposal = django_filters.ModelMultipleChoiceFilter(
+        field_name='proposals__name',
+        queryset=IKEProposal.objects.all(),
+        to_field_name='name'
     )
 
     # TODO: Remove in v4.1
@@ -231,11 +234,14 @@ class IPSecPolicyFilterSet(NetBoxModelFilterSet):
     pfs_group = django_filters.MultipleChoiceFilter(
         choices=DHGroupChoices
     )
-    ipsec_proposal_id = MultiValueNumberFilter(
-        field_name='proposals__id'
+    ipsec_proposal_id = django_filters.ModelMultipleChoiceFilter(
+        field_name='proposals',
+        queryset=IPSecProposal.objects.all()
     )
-    ipsec_proposal = MultiValueCharFilter(
-        field_name='proposals__name'
+    ipsec_proposal = django_filters.ModelMultipleChoiceFilter(
+        field_name='proposals__name',
+        queryset=IPSecProposal.objects.all(),
+        to_field_name='name'
     )
 
     # TODO: Remove in v4.1

+ 1 - 2
netbox/vpn/tests/test_filtersets.py

@@ -743,8 +743,7 @@ class L2VPNTestCase(TestCase, ChangeLoggedFilterSetTests):
     queryset = L2VPN.objects.all()
     filterset = L2VPNFilterSet
 
-    @staticmethod
-    def get_m2m_filter_name(field):
+    def get_m2m_filter_name(self, field):
         # Override filter names for import & export RouteTargets
         if field.name == 'import_targets':
             return 'import_target'