Quellcode durchsuchen

Validate filter class for foreign key fields

Jeremy Stretch vor 1 Jahr
Ursprung
Commit
a136030094

+ 8 - 3
netbox/dcim/filtersets.py

@@ -785,10 +785,13 @@ class FrontPortTemplateFilterSet(ChangeLoggedModelFilterSet, ModularDeviceTypeCo
         choices=PortTypeChoices,
         choices=PortTypeChoices,
         null_value=None
         null_value=None
     )
     )
+    rear_port_id = django_filters.ModelMultipleChoiceFilter(
+        queryset=RearPort.objects.all()
+    )
 
 
     class Meta:
     class Meta:
         model = FrontPortTemplate
         model = FrontPortTemplate
-        fields = ('id', 'name', 'label', 'type', 'color', 'rear_port_id', 'rear_port_position', 'description')
+        fields = ('id', 'name', 'label', 'type', 'color', 'rear_port_position', 'description')
 
 
 
 
 class RearPortTemplateFilterSet(ChangeLoggedModelFilterSet, ModularDeviceTypeComponentFilterSet):
 class RearPortTemplateFilterSet(ChangeLoggedModelFilterSet, ModularDeviceTypeComponentFilterSet):
@@ -1688,12 +1691,14 @@ class FrontPortFilterSet(
         choices=PortTypeChoices,
         choices=PortTypeChoices,
         null_value=None
         null_value=None
     )
     )
+    rear_port_id = django_filters.ModelMultipleChoiceFilter(
+        queryset=RearPort.objects.all()
+    )
 
 
     class Meta:
     class Meta:
         model = FrontPort
         model = FrontPort
         fields = (
         fields = (
-            'id', 'name', 'label', 'type', 'color', 'rear_port_id', 'rear_port_position', 'description',
-            'mark_connected', 'cable_end',
+            'id', 'name', 'label', 'type', 'color', 'rear_port_position', 'description', 'mark_connected', 'cable_end',
         )
         )
 
 
 
 

+ 6 - 1
netbox/ipam/filtersets.py

@@ -667,10 +667,15 @@ class IPAddressFilterSet(NetBoxModelFilterSet, TenancyFilterSet):
         queryset=Service.objects.all(),
         queryset=Service.objects.all(),
         label=_('Service (ID)'),
         label=_('Service (ID)'),
     )
     )
+    nat_inside_id = django_filters.ModelMultipleChoiceFilter(
+        field_name='nat_inside',
+        queryset=IPAddress.objects.all(),
+        label=_('NAT inside IP address (ID)'),
+    )
 
 
     class Meta:
     class Meta:
         model = IPAddress
         model = IPAddress
-        fields = ('id', 'dns_name', 'description', 'assigned_object_type', 'assigned_object_id', 'nat_inside_id')
+        fields = ('id', 'dns_name', 'description', 'assigned_object_type', 'assigned_object_id')
 
 
     def search(self, queryset, name, value):
     def search(self, queryset, name, value):
         if not value.strip():
         if not value.strip():

+ 37 - 9
netbox/utilities/testing/filtersets.py

@@ -1,11 +1,14 @@
+import django_filters
 from datetime import datetime, timezone
 from datetime import datetime, timezone
 from itertools import chain
 from itertools import chain
+from mptt.models import MPTTModel
 
 
 from django.contrib.contenttypes.fields import GenericForeignKey, GenericRelation
 from django.contrib.contenttypes.fields import GenericForeignKey, GenericRelation
 from django.contrib.contenttypes.models import ContentType
 from django.contrib.contenttypes.models import ContentType
 from django.db.models import ForeignKey, ManyToManyField, ManyToManyRel, ManyToOneRel, OneToOneRel
 from django.db.models import ForeignKey, ManyToManyField, ManyToManyRel, ManyToOneRel, OneToOneRel
 from django.utils.module_loading import import_string
 from django.utils.module_loading import import_string
 from taggit.managers import TaggableManager
 from taggit.managers import TaggableManager
+from utilities.filters import TreeNodeMultipleChoiceFilter
 
 
 from core.models import ObjectType
 from core.models import ObjectType
 
 
@@ -52,6 +55,21 @@ class BaseFilterSetTests:
         related_model_name = field.related_model._meta.verbose_name
         related_model_name = field.related_model._meta.verbose_name
         return related_model_name.lower().replace(' ', '_')
         return related_model_name.lower().replace(' ', '_')
 
 
+    @staticmethod
+    def get_filter_class_for_field(field):
+
+        # ForeignKey & OneToOneField
+        if issubclass(field.__class__, ForeignKey) or type(field) is OneToOneRel:
+
+            # ForeignKey to an MPTT-enabled model
+            if issubclass(field.related_model, MPTTModel) and field.model is not field.related_model:
+                return TreeNodeMultipleChoiceFilter
+
+            return django_filters.ModelMultipleChoiceFilter
+
+        # Unable to determine the correct filter class
+        return None
+
     def test_id(self):
     def test_id(self):
         """
         """
         Test filtering for two PKs from a set of >2 objects.
         Test filtering for two PKs from a set of >2 objects.
@@ -76,7 +94,7 @@ class BaseFilterSetTests:
         filterset = import_string(f'{app_label}.filtersets.{model_name}FilterSet')
         filterset = import_string(f'{app_label}.filtersets.{model_name}FilterSet')
         self.assertEqual(model, filterset.Meta.model, "FilterSet model does not match!")
         self.assertEqual(model, filterset.Meta.model, "FilterSet model does not match!")
 
 
-        filterset_fields = sorted(filterset.get_filters())
+        filters = filterset.get_filters()
 
 
         # Check for missing filters
         # Check for missing filters
         for model_field in model._meta.get_fields():
         for model_field in model._meta.get_fields():
@@ -95,26 +113,36 @@ class BaseFilterSetTests:
 
 
             # One-to-one & one-to-many relationships
             # One-to-one & one-to-many relationships
             if issubclass(model_field.__class__, ForeignKey) or type(model_field) is OneToOneRel:
             if issubclass(model_field.__class__, ForeignKey) or type(model_field) is OneToOneRel:
+
+                # Relationships to ContentType (used as part of a GFK) do not need a filter
                 if model_field.related_model is ContentType:
                 if model_field.related_model is ContentType:
-                    # Relationships to ContentType (used as part of a GFK) do not need a filter
                     continue
                     continue
-                elif model_field.related_model is ObjectType:
-                    # Filters to ObjectType use 'app.model' rather than numeric PK, so we omit the _id suffix
+
+                # Filters to ObjectType use 'app.model' rather than numeric PK, so we omit the _id suffix
+                if model_field.related_model is ObjectType:
                     filter_name = model_field.name
                     filter_name = model_field.name
                 else:
                 else:
                     filter_name = f'{model_field.name}_id'
                     filter_name = f'{model_field.name}_id'
+
                 self.assertIn(
                 self.assertIn(
                     filter_name,
                     filter_name,
-                    filterset_fields,
+                    filters,
                     f'No filter defined for {filter_name} ({model_field.name})!'
                     f'No filter defined for {filter_name} ({model_field.name})!'
                 )
                 )
-
+                if filter_class := self.get_filter_class_for_field(model_field):
+                    self.assertIs(
+                        type(filters[filter_name]),
+                        filter_class,
+                        f"Invalid filter class for {filter_name}!"
+                    )
+
+            # Many-to-many relationships (forward & backward)
             elif type(model_field) in (ManyToManyField, ManyToManyRel):
             elif type(model_field) in (ManyToManyField, ManyToManyRel):
                 filter_name = self.get_m2m_filter_name(model_field)
                 filter_name = self.get_m2m_filter_name(model_field)
                 filter_name = f'{filter_name}_id'
                 filter_name = f'{filter_name}_id'
                 self.assertIn(
                 self.assertIn(
                     filter_name,
                     filter_name,
-                    filterset_fields,
+                    filters,
                     f'No filter defined for {filter_name} ({model_field.name})!'
                     f'No filter defined for {filter_name} ({model_field.name})!'
                 )
                 )
 
 
@@ -124,13 +152,13 @@ class BaseFilterSetTests:
 
 
             # Tags
             # Tags
             elif type(model_field) is TaggableManager:
             elif type(model_field) is TaggableManager:
-                self.assertIn('tag', filterset_fields, f'No filter defined for {model_field.name}!')
+                self.assertIn('tag', filters, f'No filter defined for {model_field.name}!')
 
 
             # All other fields
             # All other fields
             else:
             else:
                 self.assertIn(
                 self.assertIn(
                     model_field.name,
                     model_field.name,
-                    filterset_fields,
+                    filters,
                     f'No defined found for {model_field.name} ({type(model_field)})!'
                     f'No defined found for {model_field.name} ({type(model_field)})!'
                 )
                 )
 
 

+ 6 - 2
netbox/wireless/filtersets.py

@@ -87,8 +87,12 @@ class WirelessLANFilterSet(NetBoxModelFilterSet, TenancyFilterSet):
 
 
 
 
 class WirelessLinkFilterSet(NetBoxModelFilterSet, TenancyFilterSet):
 class WirelessLinkFilterSet(NetBoxModelFilterSet, TenancyFilterSet):
-    interface_a_id = MultiValueNumberFilter()
-    interface_b_id = MultiValueNumberFilter()
+    interface_a_id = django_filters.ModelMultipleChoiceFilter(
+        queryset=Interface.objects.all()
+    )
+    interface_b_id = django_filters.ModelMultipleChoiceFilter(
+        queryset=Interface.objects.all()
+    )
     status = django_filters.MultipleChoiceFilter(
     status = django_filters.MultipleChoiceFilter(
         choices=LinkStatusChoices
         choices=LinkStatusChoices
     )
     )