Просмотр исходного кода

Extend logic for validating filter class

Jeremy Stretch 1 год назад
Родитель
Сommit
313e63622b

+ 18 - 12
netbox/extras/filtersets.py

@@ -91,8 +91,9 @@ class EventRuleFilterSet(NetBoxModelFilterSet):
         method='search',
         method='search',
         label=_('Search'),
         label=_('Search'),
     )
     )
-    object_type_id = MultiValueNumberFilter(
-        field_name='object_types__id'
+    object_type_id = django_filters.ModelMultipleChoiceFilter(
+        queryset=ObjectType.objects.all(),
+        field_name='object_types'
     )
     )
     object_type = ContentTypeFilter(
     object_type = ContentTypeFilter(
         field_name='object_types'
         field_name='object_types'
@@ -128,14 +129,16 @@ class CustomFieldFilterSet(ChangeLoggedModelFilterSet):
     type = django_filters.MultipleChoiceFilter(
     type = django_filters.MultipleChoiceFilter(
         choices=CustomFieldTypeChoices
         choices=CustomFieldTypeChoices
     )
     )
-    object_type_id = MultiValueNumberFilter(
-        field_name='object_types__id'
+    object_type_id = django_filters.ModelMultipleChoiceFilter(
+        queryset=ObjectType.objects.all(),
+        field_name='object_types'
     )
     )
     object_type = ContentTypeFilter(
     object_type = ContentTypeFilter(
         field_name='object_types'
         field_name='object_types'
     )
     )
-    related_object_type_id = MultiValueNumberFilter(
-        field_name='related_object_type__id'
+    related_object_type_id = django_filters.ModelMultipleChoiceFilter(
+        queryset=ObjectType.objects.all(),
+        field_name='related_object_type'
     )
     )
     related_object_type = ContentTypeFilter()
     related_object_type = ContentTypeFilter()
     choice_set_id = django_filters.ModelMultipleChoiceFilter(
     choice_set_id = django_filters.ModelMultipleChoiceFilter(
@@ -199,8 +202,9 @@ class CustomLinkFilterSet(ChangeLoggedModelFilterSet):
         method='search',
         method='search',
         label=_('Search'),
         label=_('Search'),
     )
     )
-    object_type_id = MultiValueNumberFilter(
-        field_name='object_types__id'
+    object_type_id = django_filters.ModelMultipleChoiceFilter(
+        queryset=ObjectType.objects.all(),
+        field_name='object_types'
     )
     )
     object_type = ContentTypeFilter(
     object_type = ContentTypeFilter(
         field_name='object_types'
         field_name='object_types'
@@ -228,8 +232,9 @@ class ExportTemplateFilterSet(ChangeLoggedModelFilterSet):
         method='search',
         method='search',
         label=_('Search'),
         label=_('Search'),
     )
     )
-    object_type_id = MultiValueNumberFilter(
-        field_name='object_types__id'
+    object_type_id = django_filters.ModelMultipleChoiceFilter(
+        queryset=ObjectType.objects.all(),
+        field_name='object_types'
     )
     )
     object_type = ContentTypeFilter(
     object_type = ContentTypeFilter(
         field_name='object_types'
         field_name='object_types'
@@ -264,8 +269,9 @@ class SavedFilterFilterSet(ChangeLoggedModelFilterSet):
         method='search',
         method='search',
         label=_('Search'),
         label=_('Search'),
     )
     )
-    object_type_id = MultiValueNumberFilter(
-        field_name='object_types__id'
+    object_type_id = django_filters.ModelMultipleChoiceFilter(
+        queryset=ObjectType.objects.all(),
+        field_name='object_types'
     )
     )
     object_type = ContentTypeFilter(
     object_type = ContentTypeFilter(
         field_name='object_types'
         field_name='object_types'

+ 2 - 4
netbox/ipam/tests/test_filtersets.py

@@ -198,8 +198,7 @@ class VRFTestCase(TestCase, ChangeLoggedFilterSetTests):
     queryset = VRF.objects.all()
     queryset = VRF.objects.all()
     filterset = VRFFilterSet
     filterset = VRFFilterSet
 
 
-    @staticmethod
-    def get_m2m_filter_name(field):
+    def get_m2m_filter_name(self, field):
         # Override filter names for import & export RouteTargets
         # Override filter names for import & export RouteTargets
         if field.name == 'import_targets':
         if field.name == 'import_targets':
             return 'import_target'
             return 'import_target'
@@ -303,8 +302,7 @@ class RouteTargetTestCase(TestCase, ChangeLoggedFilterSetTests):
     queryset = RouteTarget.objects.all()
     queryset = RouteTarget.objects.all()
     filterset = RouteTargetFilterSet
     filterset = RouteTargetFilterSet
 
 
-    @staticmethod
-    def get_m2m_filter_name(field):
+    def get_m2m_filter_name(self, field):
         # Override filter names for import & export VRFs and L2VPNs
         # Override filter names for import & export VRFs and L2VPNs
         if field.name == 'importing_vrfs':
         if field.name == 'importing_vrfs':
             return 'importing_vrf'
             return 'importing_vrf'

+ 5 - 3
netbox/users/filtersets.py

@@ -3,9 +3,10 @@ from django.contrib.auth import get_user_model
 from django.db.models import Q
 from django.db.models import Q
 from django.utils.translation import gettext as _
 from django.utils.translation import gettext as _
 
 
+from core.models import ObjectType
 from netbox.filtersets import BaseFilterSet
 from netbox.filtersets import BaseFilterSet
 from users.models import Group, ObjectPermission, Token
 from users.models import Group, ObjectPermission, Token
-from utilities.filters import ContentTypeFilter, MultiValueNumberFilter
+from utilities.filters import ContentTypeFilter
 
 
 __all__ = (
 __all__ = (
     'GroupFilterSet',
     'GroupFilterSet',
@@ -134,8 +135,9 @@ class ObjectPermissionFilterSet(BaseFilterSet):
         method='search',
         method='search',
         label=_('Search'),
         label=_('Search'),
     )
     )
-    object_type_id = MultiValueNumberFilter(
-        field_name='object_types__id'
+    object_type_id = django_filters.ModelMultipleChoiceFilter(
+        queryset=ObjectType.objects.all(),
+        field_name='object_types'
     )
     )
     object_type = ContentTypeFilter(
     object_type = ContentTypeFilter(
         field_name='object_types'
         field_name='object_types'

+ 55 - 49
netbox/utilities/testing/filtersets.py

@@ -8,7 +8,9 @@ from django.contrib.contenttypes.models import ContentType
 from django.db.models import ForeignKey, ManyToManyField, ManyToManyRel, ManyToOneRel, OneToOneRel
 from django.db.models import ForeignKey, ManyToManyField, ManyToManyRel, ManyToOneRel, OneToOneRel
 from django.utils.module_loading import import_string
 from django.utils.module_loading import import_string
 from taggit.managers import TaggableManager
 from taggit.managers import TaggableManager
-from utilities.filters import TreeNodeMultipleChoiceFilter
+
+from extras.filters import TagFilter
+from utilities.filters import ContentTypeFilter, TreeNodeMultipleChoiceFilter
 
 
 from core.models import ObjectType
 from core.models import ObjectType
 
 
@@ -46,8 +48,7 @@ class BaseFilterSetTests:
     filterset = None
     filterset = None
     ignore_fields = tuple()
     ignore_fields = tuple()
 
 
-    @staticmethod
-    def get_m2m_filter_name(field):
+    def get_m2m_filter_name(self, field):
         """
         """
         Given a ManyToManyField, determine the correct name for its corresponding Filter. Individual test
         Given a ManyToManyField, determine the correct name for its corresponding Filter. Individual test
         cases may override this method to prescribe deviations for specific fields.
         cases may override this method to prescribe deviations for specific fields.
@@ -55,20 +56,50 @@ class BaseFilterSetTests:
         related_model_name = field.related_model._meta.verbose_name
         related_model_name = field.related_model._meta.verbose_name
         return related_model_name.lower().replace(' ', '_')
         return related_model_name.lower().replace(' ', '_')
 
 
-    @staticmethod
-    def get_filter_class_for_field(field):
-
+    def get_filters_for_model_field(self, field):
+        """
+        Given a model field, return an iterable of (name, class) for each filter that should be defined on
+        the model's FilterSet class. If the appropriate filter class cannot be determined, it will be None.
+        """
         # ForeignKey & OneToOneField
         # ForeignKey & OneToOneField
         if issubclass(field.__class__, ForeignKey) or type(field) is OneToOneRel:
         if issubclass(field.__class__, ForeignKey) or type(field) is OneToOneRel:
 
 
+            # Relationships to ContentType (used as part of a GFK) do not need a filter
+            if field.related_model is ContentType:
+                return [(None, None)]
+
+            # ForeignKeys to ObjectType need two filters: 'app.model' & PK
+            if field.related_model is ObjectType:
+                return [
+                    (field.name, ContentTypeFilter),
+                    (f'{field.name}_id', django_filters.ModelMultipleChoiceFilter),
+                ]
+
             # ForeignKey to an MPTT-enabled model
             # ForeignKey to an MPTT-enabled model
             if issubclass(field.related_model, MPTTModel) and field.model is not field.related_model:
             if issubclass(field.related_model, MPTTModel) and field.model is not field.related_model:
-                return TreeNodeMultipleChoiceFilter
+                return [(f'{field.name}_id', TreeNodeMultipleChoiceFilter)]
+
+            return [(f'{field.name}_id', django_filters.ModelMultipleChoiceFilter)]
 
 
-            return django_filters.ModelMultipleChoiceFilter
+        # Many-to-many relationships (forward & backward)
+        elif type(field) in (ManyToManyField, ManyToManyRel):
+            filter_name = self.get_m2m_filter_name(field)
+
+            # ManyToManyFields to ObjectType need two filters: 'app.model' & PK
+            if field.related_model is ObjectType:
+                return [
+                    (filter_name, ContentTypeFilter),
+                    (f'{filter_name}_id', django_filters.ModelMultipleChoiceFilter),
+                ]
+
+            return [(f'{filter_name}_id', django_filters.ModelMultipleChoiceFilter)]
+
+        # Tag manager
+        if type(field) is TaggableManager:
+            return [('tag', TagFilter)]
 
 
         # Unable to determine the correct filter class
         # Unable to determine the correct filter class
-        return None
+        return [(field.name, None)]
 
 
     def test_id(self):
     def test_id(self):
         """
         """
@@ -111,57 +142,32 @@ class BaseFilterSetTests:
             if type(model_field) is ManyToOneRel:
             if type(model_field) is ManyToOneRel:
                 continue
                 continue
 
 
-            # One-to-one & one-to-many relationships
-            if issubclass(model_field.__class__, ForeignKey) or type(model_field) is OneToOneRel:
+            # TODO: Generic relationships
+            if type(model_field) in (GenericForeignKey, GenericRelation):
+                continue
 
 
-                # Relationships to ContentType (used as part of a GFK) do not need a filter
-                if model_field.related_model is ContentType:
-                    continue
+            for filter_name, filter_class in self.get_filters_for_model_field(model_field):
 
 
-                # Filters to ObjectType use 'app.model' rather than numeric PK, so we omit the _id suffix
-                if model_field.related_model is ObjectType:
-                    filter_name = model_field.name
-                else:
-                    filter_name = f'{model_field.name}_id'
+                if filter_name is None:
+                    # Field is exempt
+                    continue
 
 
+                # Check that the filter is defined
                 self.assertIn(
                 self.assertIn(
                     filter_name,
                     filter_name,
-                    filters,
+                    filters.keys(),
                     f'No filter defined for {filter_name} ({model_field.name})!'
                     f'No filter defined for {filter_name} ({model_field.name})!'
                 )
                 )
-                if filter_class := self.get_filter_class_for_field(model_field):
+
+                # Check that the filter class is correct
+                filter = filters[filter_name]
+                if filter_class is not None:
                     self.assertIs(
                     self.assertIs(
-                        type(filters[filter_name]),
+                        type(filter),
                         filter_class,
                         filter_class,
-                        f"Invalid filter class for {filter_name}!"
+                        f"Invalid filter class {type(filter)} for {filter_name} (should be {filter_class})!"
                     )
                     )
 
 
-            # Many-to-many relationships (forward & backward)
-            elif type(model_field) in (ManyToManyField, ManyToManyRel):
-                filter_name = self.get_m2m_filter_name(model_field)
-                filter_name = f'{filter_name}_id'
-                self.assertIn(
-                    filter_name,
-                    filters,
-                    f'No filter defined for {filter_name} ({model_field.name})!'
-                )
-
-            # TODO: Generic relationships
-            elif type(model_field) in (GenericForeignKey, GenericRelation):
-                continue
-
-            # Tags
-            elif type(model_field) is TaggableManager:
-                self.assertIn('tag', filters, f'No filter defined for {model_field.name}!')
-
-            # All other fields
-            else:
-                self.assertIn(
-                    model_field.name,
-                    filters,
-                    f'No defined found for {model_field.name} ({type(model_field)})!'
-                )
-
 
 
 class ChangeLoggedFilterSetTests(BaseFilterSetTests):
 class ChangeLoggedFilterSetTests(BaseFilterSetTests):
 
 

+ 14 - 8
netbox/vpn/filtersets.py

@@ -169,11 +169,14 @@ class IKEPolicyFilterSet(NetBoxModelFilterSet):
     mode = django_filters.MultipleChoiceFilter(
     mode = django_filters.MultipleChoiceFilter(
         choices=IKEModeChoices
         choices=IKEModeChoices
     )
     )
-    ike_proposal_id = MultiValueNumberFilter(
-        field_name='proposals__id'
+    ike_proposal_id = django_filters.ModelMultipleChoiceFilter(
+        field_name='proposals',
+        queryset=IKEProposal.objects.all()
     )
     )
-    ike_proposal = MultiValueCharFilter(
-        field_name='proposals__name'
+    ike_proposal = django_filters.ModelMultipleChoiceFilter(
+        field_name='proposals__name',
+        queryset=IKEProposal.objects.all(),
+        to_field_name='name'
     )
     )
 
 
     # TODO: Remove in v4.1
     # TODO: Remove in v4.1
@@ -231,11 +234,14 @@ class IPSecPolicyFilterSet(NetBoxModelFilterSet):
     pfs_group = django_filters.MultipleChoiceFilter(
     pfs_group = django_filters.MultipleChoiceFilter(
         choices=DHGroupChoices
         choices=DHGroupChoices
     )
     )
-    ipsec_proposal_id = MultiValueNumberFilter(
-        field_name='proposals__id'
+    ipsec_proposal_id = django_filters.ModelMultipleChoiceFilter(
+        field_name='proposals',
+        queryset=IPSecProposal.objects.all()
     )
     )
-    ipsec_proposal = MultiValueCharFilter(
-        field_name='proposals__name'
+    ipsec_proposal = django_filters.ModelMultipleChoiceFilter(
+        field_name='proposals__name',
+        queryset=IPSecProposal.objects.all(),
+        to_field_name='name'
     )
     )
 
 
     # TODO: Remove in v4.1
     # TODO: Remove in v4.1

+ 1 - 2
netbox/vpn/tests/test_filtersets.py

@@ -743,8 +743,7 @@ class L2VPNTestCase(TestCase, ChangeLoggedFilterSetTests):
     queryset = L2VPN.objects.all()
     queryset = L2VPN.objects.all()
     filterset = L2VPNFilterSet
     filterset = L2VPNFilterSet
 
 
-    @staticmethod
-    def get_m2m_filter_name(field):
+    def get_m2m_filter_name(self, field):
         # Override filter names for import & export RouteTargets
         # Override filter names for import & export RouteTargets
         if field.name == 'import_targets':
         if field.name == 'import_targets':
             return 'import_target'
             return 'import_target'