Просмотр исходного кода

Validate filter class for foreign key fields

Jeremy Stretch 1 год назад
Родитель
Сommit
a136030094

+ 8 - 3
netbox/dcim/filtersets.py

@@ -785,10 +785,13 @@ class FrontPortTemplateFilterSet(ChangeLoggedModelFilterSet, ModularDeviceTypeCo
         choices=PortTypeChoices,
         null_value=None
     )
+    rear_port_id = django_filters.ModelMultipleChoiceFilter(
+        queryset=RearPort.objects.all()
+    )
 
     class Meta:
         model = FrontPortTemplate
-        fields = ('id', 'name', 'label', 'type', 'color', 'rear_port_id', 'rear_port_position', 'description')
+        fields = ('id', 'name', 'label', 'type', 'color', 'rear_port_position', 'description')
 
 
 class RearPortTemplateFilterSet(ChangeLoggedModelFilterSet, ModularDeviceTypeComponentFilterSet):
@@ -1688,12 +1691,14 @@ class FrontPortFilterSet(
         choices=PortTypeChoices,
         null_value=None
     )
+    rear_port_id = django_filters.ModelMultipleChoiceFilter(
+        queryset=RearPort.objects.all()
+    )
 
     class Meta:
         model = FrontPort
         fields = (
-            'id', 'name', 'label', 'type', 'color', 'rear_port_id', 'rear_port_position', 'description',
-            'mark_connected', 'cable_end',
+            'id', 'name', 'label', 'type', 'color', 'rear_port_position', 'description', 'mark_connected', 'cable_end',
         )
 
 

+ 6 - 1
netbox/ipam/filtersets.py

@@ -667,10 +667,15 @@ class IPAddressFilterSet(NetBoxModelFilterSet, TenancyFilterSet):
         queryset=Service.objects.all(),
         label=_('Service (ID)'),
     )
+    nat_inside_id = django_filters.ModelMultipleChoiceFilter(
+        field_name='nat_inside',
+        queryset=IPAddress.objects.all(),
+        label=_('NAT inside IP address (ID)'),
+    )
 
     class Meta:
         model = IPAddress
-        fields = ('id', 'dns_name', 'description', 'assigned_object_type', 'assigned_object_id', 'nat_inside_id')
+        fields = ('id', 'dns_name', 'description', 'assigned_object_type', 'assigned_object_id')
 
     def search(self, queryset, name, value):
         if not value.strip():

+ 37 - 9
netbox/utilities/testing/filtersets.py

@@ -1,11 +1,14 @@
+import django_filters
 from datetime import datetime, timezone
 from itertools import chain
+from mptt.models import MPTTModel
 
 from django.contrib.contenttypes.fields import GenericForeignKey, GenericRelation
 from django.contrib.contenttypes.models import ContentType
 from django.db.models import ForeignKey, ManyToManyField, ManyToManyRel, ManyToOneRel, OneToOneRel
 from django.utils.module_loading import import_string
 from taggit.managers import TaggableManager
+from utilities.filters import TreeNodeMultipleChoiceFilter
 
 from core.models import ObjectType
 
@@ -52,6 +55,21 @@ class BaseFilterSetTests:
         related_model_name = field.related_model._meta.verbose_name
         return related_model_name.lower().replace(' ', '_')
 
+    @staticmethod
+    def get_filter_class_for_field(field):
+
+        # ForeignKey & OneToOneField
+        if issubclass(field.__class__, ForeignKey) or type(field) is OneToOneRel:
+
+            # ForeignKey to an MPTT-enabled model
+            if issubclass(field.related_model, MPTTModel) and field.model is not field.related_model:
+                return TreeNodeMultipleChoiceFilter
+
+            return django_filters.ModelMultipleChoiceFilter
+
+        # Unable to determine the correct filter class
+        return None
+
     def test_id(self):
         """
         Test filtering for two PKs from a set of >2 objects.
@@ -76,7 +94,7 @@ class BaseFilterSetTests:
         filterset = import_string(f'{app_label}.filtersets.{model_name}FilterSet')
         self.assertEqual(model, filterset.Meta.model, "FilterSet model does not match!")
 
-        filterset_fields = sorted(filterset.get_filters())
+        filters = filterset.get_filters()
 
         # Check for missing filters
         for model_field in model._meta.get_fields():
@@ -95,26 +113,36 @@ class BaseFilterSetTests:
 
             # One-to-one & one-to-many relationships
             if issubclass(model_field.__class__, ForeignKey) or type(model_field) is OneToOneRel:
+
+                # Relationships to ContentType (used as part of a GFK) do not need a filter
                 if model_field.related_model is ContentType:
-                    # Relationships to ContentType (used as part of a GFK) do not need a filter
                     continue
-                elif model_field.related_model is ObjectType:
-                    # Filters to ObjectType use 'app.model' rather than numeric PK, so we omit the _id suffix
+
+                # Filters to ObjectType use 'app.model' rather than numeric PK, so we omit the _id suffix
+                if model_field.related_model is ObjectType:
                     filter_name = model_field.name
                 else:
                     filter_name = f'{model_field.name}_id'
+
                 self.assertIn(
                     filter_name,
-                    filterset_fields,
+                    filters,
                     f'No filter defined for {filter_name} ({model_field.name})!'
                 )
-
+                if filter_class := self.get_filter_class_for_field(model_field):
+                    self.assertIs(
+                        type(filters[filter_name]),
+                        filter_class,
+                        f"Invalid filter class for {filter_name}!"
+                    )
+
+            # Many-to-many relationships (forward & backward)
             elif type(model_field) in (ManyToManyField, ManyToManyRel):
                 filter_name = self.get_m2m_filter_name(model_field)
                 filter_name = f'{filter_name}_id'
                 self.assertIn(
                     filter_name,
-                    filterset_fields,
+                    filters,
                     f'No filter defined for {filter_name} ({model_field.name})!'
                 )
 
@@ -124,13 +152,13 @@ class BaseFilterSetTests:
 
             # Tags
             elif type(model_field) is TaggableManager:
-                self.assertIn('tag', filterset_fields, f'No filter defined for {model_field.name}!')
+                self.assertIn('tag', filters, f'No filter defined for {model_field.name}!')
 
             # All other fields
             else:
                 self.assertIn(
                     model_field.name,
-                    filterset_fields,
+                    filters,
                     f'No defined found for {model_field.name} ({type(model_field)})!'
                 )
 

+ 6 - 2
netbox/wireless/filtersets.py

@@ -87,8 +87,12 @@ class WirelessLANFilterSet(NetBoxModelFilterSet, TenancyFilterSet):
 
 
 class WirelessLinkFilterSet(NetBoxModelFilterSet, TenancyFilterSet):
-    interface_a_id = MultiValueNumberFilter()
-    interface_b_id = MultiValueNumberFilter()
+    interface_a_id = django_filters.ModelMultipleChoiceFilter(
+        queryset=Interface.objects.all()
+    )
+    interface_b_id = django_filters.ModelMultipleChoiceFilter(
+        queryset=Interface.objects.all()
+    )
     status = django_filters.MultipleChoiceFilter(
         choices=LinkStatusChoices
     )