Просмотр исходного кода

Closes #15383: Standardize filtering logic for the parents of recursively-nested models

Jeremy Stretch 1 год назад
Родитель
Сommit
d6acc18c29

+ 38 - 2
netbox/dcim/filtersets.py

@@ -89,6 +89,19 @@ class RegionFilterSet(OrganizationalModelFilterSet, ContactModelFilterSet):
         to_field_name='slug',
         label=_('Parent region (slug)'),
     )
+    ancestor_id = TreeNodeMultipleChoiceFilter(
+        queryset=Region.objects.all(),
+        field_name='parent',
+        lookup_expr='in',
+        label=_('Region (ID)'),
+    )
+    ancestor = TreeNodeMultipleChoiceFilter(
+        queryset=Region.objects.all(),
+        field_name='parent',
+        lookup_expr='in',
+        to_field_name='slug',
+        label=_('Region (slug)'),
+    )
 
     class Meta:
         model = Region
@@ -106,6 +119,19 @@ class SiteGroupFilterSet(OrganizationalModelFilterSet, ContactModelFilterSet):
         to_field_name='slug',
         label=_('Parent site group (slug)'),
     )
+    ancestor_id = TreeNodeMultipleChoiceFilter(
+        queryset=SiteGroup.objects.all(),
+        field_name='parent',
+        lookup_expr='in',
+        label=_('Site group (ID)'),
+    )
+    ancestor = TreeNodeMultipleChoiceFilter(
+        queryset=SiteGroup.objects.all(),
+        field_name='parent',
+        lookup_expr='in',
+        to_field_name='slug',
+        label=_('Site group (slug)'),
+    )
 
     class Meta:
         model = SiteGroup
@@ -214,13 +240,23 @@ class LocationFilterSet(TenancyFilterSet, ContactModelFilterSet, OrganizationalM
         to_field_name='slug',
         label=_('Site (slug)'),
     )
-    parent_id = TreeNodeMultipleChoiceFilter(
+    parent_id = django_filters.ModelMultipleChoiceFilter(
+        queryset=Location.objects.all(),
+        label=_('Parent location (ID)'),
+    )
+    parent = django_filters.ModelMultipleChoiceFilter(
+        field_name='parent__slug',
+        queryset=Location.objects.all(),
+        to_field_name='slug',
+        label=_('Parent location (slug)'),
+    )
+    ancestor_id = TreeNodeMultipleChoiceFilter(
         queryset=Location.objects.all(),
         field_name='parent',
         lookup_expr='in',
         label=_('Location (ID)'),
     )
-    parent = TreeNodeMultipleChoiceFilter(
+    ancestor = TreeNodeMultipleChoiceFilter(
         queryset=Location.objects.all(),
         field_name='parent',
         lookup_expr='in',

+ 92 - 41
netbox/dcim/tests/test_filtersets.py

@@ -64,21 +64,32 @@ class RegionTestCase(TestCase, ChangeLoggedFilterSetTests):
     @classmethod
     def setUpTestData(cls):
 
-        regions = (
+        parent_regions = (
             Region(name='Region 1', slug='region-1', description='foobar1'),
             Region(name='Region 2', slug='region-2', description='foobar2'),
             Region(name='Region 3', slug='region-3', description='foobar3'),
         )
+        for region in parent_regions:
+            region.save()
+
+        regions = (
+            Region(name='Region 1A', slug='region-1a', parent=parent_regions[0]),
+            Region(name='Region 1B', slug='region-1b', parent=parent_regions[0]),
+            Region(name='Region 2A', slug='region-2a', parent=parent_regions[1]),
+            Region(name='Region 2B', slug='region-2b', parent=parent_regions[1]),
+            Region(name='Region 3A', slug='region-3a', parent=parent_regions[2]),
+            Region(name='Region 3B', slug='region-3b', parent=parent_regions[2]),
+        )
         for region in regions:
             region.save()
 
         child_regions = (
-            Region(name='Region 1A', slug='region-1a', parent=regions[0]),
-            Region(name='Region 1B', slug='region-1b', parent=regions[0]),
-            Region(name='Region 2A', slug='region-2a', parent=regions[1]),
-            Region(name='Region 2B', slug='region-2b', parent=regions[1]),
-            Region(name='Region 3A', slug='region-3a', parent=regions[2]),
-            Region(name='Region 3B', slug='region-3b', parent=regions[2]),
+            Region(name='Region 1A1', slug='region-1a1', parent=regions[0]),
+            Region(name='Region 1B1', slug='region-1b1', parent=regions[1]),
+            Region(name='Region 2A1', slug='region-2a1', parent=regions[2]),
+            Region(name='Region 2B1', slug='region-2b1', parent=regions[3]),
+            Region(name='Region 3A1', slug='region-3a1', parent=regions[4]),
+            Region(name='Region 3B1', slug='region-3b1', parent=regions[5]),
         )
         for region in child_regions:
             region.save()
@@ -100,12 +111,19 @@ class RegionTestCase(TestCase, ChangeLoggedFilterSetTests):
         self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
 
     def test_parent(self):
-        parent_regions = Region.objects.filter(parent__isnull=True)[:2]
-        params = {'parent_id': [parent_regions[0].pk, parent_regions[1].pk]}
+        regions = Region.objects.filter(parent__isnull=True)[:2]
+        params = {'parent_id': [regions[0].pk, regions[1].pk]}
         self.assertEqual(self.filterset(params, self.queryset).qs.count(), 4)
-        params = {'parent': [parent_regions[0].slug, parent_regions[1].slug]}
+        params = {'parent': [regions[0].slug, regions[1].slug]}
         self.assertEqual(self.filterset(params, self.queryset).qs.count(), 4)
 
+    def test_ancestor(self):
+        regions = Region.objects.filter(parent__isnull=True)[:2]
+        params = {'ancestor_id': [regions[0].pk, regions[1].pk]}
+        self.assertEqual(self.filterset(params, self.queryset).qs.count(), 8)
+        params = {'ancestor': [regions[0].slug, regions[1].slug]}
+        self.assertEqual(self.filterset(params, self.queryset).qs.count(), 8)
+
 
 class SiteGroupTestCase(TestCase, ChangeLoggedFilterSetTests):
     queryset = SiteGroup.objects.all()
@@ -114,24 +132,35 @@ class SiteGroupTestCase(TestCase, ChangeLoggedFilterSetTests):
     @classmethod
     def setUpTestData(cls):
 
-        sitegroups = (
+        parent_groups = (
             SiteGroup(name='Site Group 1', slug='site-group-1', description='foobar1'),
             SiteGroup(name='Site Group 2', slug='site-group-2', description='foobar2'),
             SiteGroup(name='Site Group 3', slug='site-group-3', description='foobar3'),
         )
-        for sitegroup in sitegroups:
-            sitegroup.save()
+        for site_group in parent_groups:
+            site_group.save()
 
-        child_sitegroups = (
-            SiteGroup(name='Site Group 1A', slug='site-group-1a', parent=sitegroups[0]),
-            SiteGroup(name='Site Group 1B', slug='site-group-1b', parent=sitegroups[0]),
-            SiteGroup(name='Site Group 2A', slug='site-group-2a', parent=sitegroups[1]),
-            SiteGroup(name='Site Group 2B', slug='site-group-2b', parent=sitegroups[1]),
-            SiteGroup(name='Site Group 3A', slug='site-group-3a', parent=sitegroups[2]),
-            SiteGroup(name='Site Group 3B', slug='site-group-3b', parent=sitegroups[2]),
-        )
-        for sitegroup in child_sitegroups:
-            sitegroup.save()
+        groups = (
+            SiteGroup(name='Site Group 1A', slug='site-group-1a', parent=parent_groups[0]),
+            SiteGroup(name='Site Group 1B', slug='site-group-1b', parent=parent_groups[0]),
+            SiteGroup(name='Site Group 2A', slug='site-group-2a', parent=parent_groups[1]),
+            SiteGroup(name='Site Group 2B', slug='site-group-2b', parent=parent_groups[1]),
+            SiteGroup(name='Site Group 3A', slug='site-group-3a', parent=parent_groups[2]),
+            SiteGroup(name='Site Group 3B', slug='site-group-3b', parent=parent_groups[2]),
+        )
+        for site_group in groups:
+            site_group.save()
+
+        child_groups = (
+            SiteGroup(name='Site Group 1A1', slug='site-group-1a1', parent=groups[0]),
+            SiteGroup(name='Site Group 1B1', slug='site-group-1b1', parent=groups[1]),
+            SiteGroup(name='Site Group 2A1', slug='site-group-2a1', parent=groups[2]),
+            SiteGroup(name='Site Group 2B1', slug='site-group-2b1', parent=groups[3]),
+            SiteGroup(name='Site Group 3A1', slug='site-group-3a1', parent=groups[4]),
+            SiteGroup(name='Site Group 3B1', slug='site-group-3b1', parent=groups[5]),
+        )
+        for site_group in child_groups:
+            site_group.save()
 
     def test_q(self):
         params = {'q': 'foobar1'}
@@ -150,12 +179,19 @@ class SiteGroupTestCase(TestCase, ChangeLoggedFilterSetTests):
         self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
 
     def test_parent(self):
-        parent_sitegroups = SiteGroup.objects.filter(parent__isnull=True)[:2]
-        params = {'parent_id': [parent_sitegroups[0].pk, parent_sitegroups[1].pk]}
+        site_groups = SiteGroup.objects.filter(parent__isnull=True)[:2]
+        params = {'parent_id': [site_groups[0].pk, site_groups[1].pk]}
         self.assertEqual(self.filterset(params, self.queryset).qs.count(), 4)
-        params = {'parent': [parent_sitegroups[0].slug, parent_sitegroups[1].slug]}
+        params = {'parent': [site_groups[0].slug, site_groups[1].slug]}
         self.assertEqual(self.filterset(params, self.queryset).qs.count(), 4)
 
+    def test_ancestor(self):
+        site_groups = SiteGroup.objects.filter(parent__isnull=True)[:2]
+        params = {'ancestor_id': [site_groups[0].pk, site_groups[1].pk]}
+        self.assertEqual(self.filterset(params, self.queryset).qs.count(), 8)
+        params = {'ancestor': [site_groups[0].slug, site_groups[1].slug]}
+        self.assertEqual(self.filterset(params, self.queryset).qs.count(), 8)
+
 
 class SiteTestCase(TestCase, ChangeLoggedFilterSetTests):
     queryset = Site.objects.all()
@@ -314,21 +350,29 @@ class LocationTestCase(TestCase, ChangeLoggedFilterSetTests):
         Site.objects.bulk_create(sites)
 
         parent_locations = (
-            Location(name='Parent Location 1', slug='parent-location-1', site=sites[0]),
-            Location(name='Parent Location 2', slug='parent-location-2', site=sites[1]),
-            Location(name='Parent Location 3', slug='parent-location-3', site=sites[2]),
+            Location(name='Location 1', slug='location-1', site=sites[0]),
+            Location(name='Location 2', slug='location-2', site=sites[1]),
+            Location(name='Location 3', slug='location-3', site=sites[2]),
         )
         for location in parent_locations:
             location.save()
 
         locations = (
-            Location(name='Location 1', slug='location-1', site=sites[0], parent=parent_locations[0], status=LocationStatusChoices.STATUS_PLANNED, description='foobar1'),
-            Location(name='Location 2', slug='location-2', site=sites[1], parent=parent_locations[1], status=LocationStatusChoices.STATUS_STAGING, description='foobar2'),
-            Location(name='Location 3', slug='location-3', site=sites[2], parent=parent_locations[2], status=LocationStatusChoices.STATUS_DECOMMISSIONING, description='foobar3'),
+            Location(name='Location 1A', slug='location-1a', site=sites[0], parent=parent_locations[0], status=LocationStatusChoices.STATUS_PLANNED, description='foobar1'),
+            Location(name='Location 2A', slug='location-2a', site=sites[1], parent=parent_locations[1], status=LocationStatusChoices.STATUS_STAGING, description='foobar2'),
+            Location(name='Location 3A', slug='location-3a', site=sites[2], parent=parent_locations[2], status=LocationStatusChoices.STATUS_DECOMMISSIONING, description='foobar3'),
         )
         for location in locations:
             location.save()
 
+        child_locations = (
+            Location(name='Location 1A1', slug='location-1a1', site=sites[0], parent=locations[0]),
+            Location(name='Location 2A1', slug='location-2a1', site=sites[1], parent=locations[1]),
+            Location(name='Location 3A1', slug='location-3a1', site=sites[2], parent=locations[2]),
+        )
+        for location in child_locations:
+            location.save()
+
     def test_q(self):
         params = {'q': 'foobar1'}
         self.assertEqual(self.filterset(params, self.queryset).qs.count(), 1)
@@ -352,31 +396,38 @@ class LocationTestCase(TestCase, ChangeLoggedFilterSetTests):
     def test_region(self):
         regions = Region.objects.all()[:2]
         params = {'region_id': [regions[0].pk, regions[1].pk]}
-        self.assertEqual(self.filterset(params, self.queryset).qs.count(), 4)
+        self.assertEqual(self.filterset(params, self.queryset).qs.count(), 6)
         params = {'region': [regions[0].slug, regions[1].slug]}
-        self.assertEqual(self.filterset(params, self.queryset).qs.count(), 4)
+        self.assertEqual(self.filterset(params, self.queryset).qs.count(), 6)
 
     def test_site_group(self):
         site_groups = SiteGroup.objects.all()[:2]
         params = {'site_group_id': [site_groups[0].pk, site_groups[1].pk]}
-        self.assertEqual(self.filterset(params, self.queryset).qs.count(), 4)
+        self.assertEqual(self.filterset(params, self.queryset).qs.count(), 6)
         params = {'site_group': [site_groups[0].slug, site_groups[1].slug]}
-        self.assertEqual(self.filterset(params, self.queryset).qs.count(), 4)
+        self.assertEqual(self.filterset(params, self.queryset).qs.count(), 6)
 
     def test_site(self):
         sites = Site.objects.all()[:2]
         params = {'site_id': [sites[0].pk, sites[1].pk]}
-        self.assertEqual(self.filterset(params, self.queryset).qs.count(), 4)
+        self.assertEqual(self.filterset(params, self.queryset).qs.count(), 6)
         params = {'site': [sites[0].slug, sites[1].slug]}
-        self.assertEqual(self.filterset(params, self.queryset).qs.count(), 4)
+        self.assertEqual(self.filterset(params, self.queryset).qs.count(), 6)
 
     def test_parent(self):
-        parent_groups = Location.objects.filter(name__startswith='Parent')[:2]
-        params = {'parent_id': [parent_groups[0].pk, parent_groups[1].pk]}
+        locations = Location.objects.filter(parent__isnull=True)[:2]
+        params = {'parent_id': [locations[0].pk, locations[1].pk]}
         self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
-        params = {'parent': [parent_groups[0].slug, parent_groups[1].slug]}
+        params = {'parent': [locations[0].slug, locations[1].slug]}
         self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
 
+    def test_ancestor(self):
+        locations = Location.objects.filter(parent__isnull=True)[:2]
+        params = {'ancestor_id': [locations[0].pk, locations[1].pk]}
+        self.assertEqual(self.filterset(params, self.queryset).qs.count(), 4)
+        params = {'ancestor': [locations[0].slug, locations[1].slug]}
+        self.assertEqual(self.filterset(params, self.queryset).qs.count(), 4)
+
 
 class RackRoleTestCase(TestCase, ChangeLoggedFilterSetTests):
     queryset = RackRole.objects.all()

+ 28 - 2
netbox/tenancy/filtersets.py

@@ -26,12 +26,25 @@ __all__ = (
 class ContactGroupFilterSet(OrganizationalModelFilterSet):
     parent_id = django_filters.ModelMultipleChoiceFilter(
         queryset=ContactGroup.objects.all(),
-        label=_('Contact group (ID)'),
+        label=_('Parent contact group (ID)'),
     )
     parent = django_filters.ModelMultipleChoiceFilter(
         field_name='parent__slug',
         queryset=ContactGroup.objects.all(),
         to_field_name='slug',
+        label=_('Parent contact group (slug)'),
+    )
+    ancestor_id = TreeNodeMultipleChoiceFilter(
+        queryset=ContactGroup.objects.all(),
+        field_name='parent',
+        lookup_expr='in',
+        label=_('Contact group (ID)'),
+    )
+    ancestor = TreeNodeMultipleChoiceFilter(
+        queryset=ContactGroup.objects.all(),
+        field_name='parent',
+        lookup_expr='in',
+        to_field_name='slug',
         label=_('Contact group (slug)'),
     )
 
@@ -155,12 +168,25 @@ class ContactModelFilterSet(django_filters.FilterSet):
 class TenantGroupFilterSet(OrganizationalModelFilterSet):
     parent_id = django_filters.ModelMultipleChoiceFilter(
         queryset=TenantGroup.objects.all(),
-        label=_('Tenant group (ID)'),
+        label=_('Parent tenant group (ID)'),
     )
     parent = django_filters.ModelMultipleChoiceFilter(
         field_name='parent__slug',
         queryset=TenantGroup.objects.all(),
         to_field_name='slug',
+        label=_('Parent tenant group (slug)'),
+    )
+    ancestor_id = TreeNodeMultipleChoiceFilter(
+        queryset=TenantGroup.objects.all(),
+        field_name='parent',
+        lookup_expr='in',
+        label=_('Tenant group (ID)'),
+    )
+    ancestor = TreeNodeMultipleChoiceFilter(
+        queryset=TenantGroup.objects.all(),
+        field_name='parent',
+        lookup_expr='in',
+        to_field_name='slug',
         label=_('Tenant group (slug)'),
     )
 

+ 62 - 32
netbox/tenancy/tests/test_filtersets.py

@@ -15,35 +15,43 @@ class TenantGroupTestCase(TestCase, ChangeLoggedFilterSetTests):
     def setUpTestData(cls):
 
         parent_tenant_groups = (
-            TenantGroup(name='Parent Tenant Group 1', slug='parent-tenant-group-1'),
-            TenantGroup(name='Parent Tenant Group 2', slug='parent-tenant-group-2'),
-            TenantGroup(name='Parent Tenant Group 3', slug='parent-tenant-group-3'),
+            TenantGroup(name='Tenant Group 1', slug='tenant-group-1'),
+            TenantGroup(name='Tenant Group 2', slug='tenant-group-2'),
+            TenantGroup(name='Tenant Group 3', slug='tenant-group-3'),
         )
-        for tenantgroup in parent_tenant_groups:
-            tenantgroup.save()
+        for tenant_group in parent_tenant_groups:
+            tenant_group.save()
 
         tenant_groups = (
             TenantGroup(
-                name='Tenant Group 1',
-                slug='tenant-group-1',
+                name='Tenant Group 1A',
+                slug='tenant-group-1a',
                 parent=parent_tenant_groups[0],
                 description='foobar1'
             ),
             TenantGroup(
-                name='Tenant Group 2',
-                slug='tenant-group-2',
+                name='Tenant Group 2A',
+                slug='tenant-group-2a',
                 parent=parent_tenant_groups[1],
                 description='foobar2'
             ),
             TenantGroup(
-                name='Tenant Group 3',
-                slug='tenant-group-3',
+                name='Tenant Group 3A',
+                slug='tenant-group-3a',
                 parent=parent_tenant_groups[2],
                 description='foobar3'
             ),
         )
-        for tenantgroup in tenant_groups:
-            tenantgroup.save()
+        for tenant_group in tenant_groups:
+            tenant_group.save()
+
+        child_tenant_groups = (
+            TenantGroup(name='Tenant Group 1A1', slug='tenant-group-1a1', parent=tenant_groups[0]),
+            TenantGroup(name='Tenant Group 2A1', slug='tenant-group-2a1', parent=tenant_groups[1]),
+            TenantGroup(name='Tenant Group 3A1', slug='tenant-group-3a1', parent=tenant_groups[2]),
+        )
+        for tenant_group in child_tenant_groups:
+            tenant_group.save()
 
     def test_q(self):
         params = {'q': 'foobar1'}
@@ -62,12 +70,19 @@ class TenantGroupTestCase(TestCase, ChangeLoggedFilterSetTests):
         self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
 
     def test_parent(self):
-        parent_groups = TenantGroup.objects.filter(name__startswith='Parent')[:2]
-        params = {'parent_id': [parent_groups[0].pk, parent_groups[1].pk]}
+        tenant_groups = TenantGroup.objects.filter(parent__isnull=True)[:2]
+        params = {'parent_id': [tenant_groups[0].pk, tenant_groups[1].pk]}
         self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
-        params = {'parent': [parent_groups[0].slug, parent_groups[1].slug]}
+        params = {'parent': [tenant_groups[0].slug, tenant_groups[1].slug]}
         self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
 
+    def test_ancestor(self):
+        tenant_groups = TenantGroup.objects.filter(parent__isnull=True)[:2]
+        params = {'ancestor_id': [tenant_groups[0].pk, tenant_groups[1].pk]}
+        self.assertEqual(self.filterset(params, self.queryset).qs.count(), 4)
+        params = {'ancestor': [tenant_groups[0].slug, tenant_groups[1].slug]}
+        self.assertEqual(self.filterset(params, self.queryset).qs.count(), 4)
+
 
 class TenantTestCase(TestCase, ChangeLoggedFilterSetTests):
     queryset = Tenant.objects.all()
@@ -123,35 +138,43 @@ class ContactGroupTestCase(TestCase, ChangeLoggedFilterSetTests):
     def setUpTestData(cls):
 
         parent_contact_groups = (
-            ContactGroup(name='Parent Contact Group 1', slug='parent-contact-group-1'),
-            ContactGroup(name='Parent Contact Group 2', slug='parent-contact-group-2'),
-            ContactGroup(name='Parent Contact Group 3', slug='parent-contact-group-3'),
+            ContactGroup(name='Contact Group 1', slug='contact-group-1'),
+            ContactGroup(name='Contact Group 2', slug='contact-group-2'),
+            ContactGroup(name='Contact Group 3', slug='contact-group-3'),
         )
-        for contactgroup in parent_contact_groups:
-            contactgroup.save()
+        for contact_group in parent_contact_groups:
+            contact_group.save()
 
         contact_groups = (
             ContactGroup(
-                name='Contact Group 1',
-                slug='contact-group-1',
+                name='Contact Group 1A',
+                slug='contact-group-1a',
                 parent=parent_contact_groups[0],
                 description='foobar1'
             ),
             ContactGroup(
-                name='Contact Group 2',
-                slug='contact-group-2',
+                name='Contact Group 2A',
+                slug='contact-group-2a',
                 parent=parent_contact_groups[1],
                 description='foobar2'
             ),
             ContactGroup(
-                name='Contact Group 3',
-                slug='contact-group-3',
+                name='Contact Group 3A',
+                slug='contact-group-3a',
                 parent=parent_contact_groups[2],
                 description='foobar3'
             ),
         )
-        for contactgroup in contact_groups:
-            contactgroup.save()
+        for contact_group in contact_groups:
+            contact_group.save()
+
+        child_contact_groups = (
+            ContactGroup(name='Contact Group 1A1', slug='contact-group-1a1', parent=contact_groups[0]),
+            ContactGroup(name='Contact Group 2A1', slug='contact-group-2a1', parent=contact_groups[1]),
+            ContactGroup(name='Contact Group 3A1', slug='contact-group-3a1', parent=contact_groups[2]),
+        )
+        for contact_group in child_contact_groups:
+            contact_group.save()
 
     def test_q(self):
         params = {'q': 'foobar1'}
@@ -170,12 +193,19 @@ class ContactGroupTestCase(TestCase, ChangeLoggedFilterSetTests):
         self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
 
     def test_parent(self):
-        parent_groups = ContactGroup.objects.filter(parent__isnull=True)[:2]
-        params = {'parent_id': [parent_groups[0].pk, parent_groups[1].pk]}
+        contact_groups = ContactGroup.objects.filter(parent__isnull=True)[:2]
+        params = {'parent_id': [contact_groups[0].pk, contact_groups[1].pk]}
         self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
-        params = {'parent': [parent_groups[0].slug, parent_groups[1].slug]}
+        params = {'parent': [contact_groups[0].slug, contact_groups[1].slug]}
         self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
 
+    def test_ancestor(self):
+        contact_groups = ContactGroup.objects.filter(parent__isnull=True)[:2]
+        params = {'ancestor_id': [contact_groups[0].pk, contact_groups[1].pk]}
+        self.assertEqual(self.filterset(params, self.queryset).qs.count(), 4)
+        params = {'ancestor': [contact_groups[0].slug, contact_groups[1].slug]}
+        self.assertEqual(self.filterset(params, self.queryset).qs.count(), 4)
+
 
 class ContactRoleTestCase(TestCase, ChangeLoggedFilterSetTests):
     queryset = ContactRole.objects.all()

+ 11 - 0
netbox/wireless/filtersets.py

@@ -25,6 +25,17 @@ class WirelessLANGroupFilterSet(OrganizationalModelFilterSet):
         queryset=WirelessLANGroup.objects.all(),
         to_field_name='slug'
     )
+    ancestor_id = TreeNodeMultipleChoiceFilter(
+        queryset=WirelessLANGroup.objects.all(),
+        field_name='parent',
+        lookup_expr='in'
+    )
+    ancestor = TreeNodeMultipleChoiceFilter(
+        queryset=WirelessLANGroup.objects.all(),
+        field_name='parent',
+        lookup_expr='in',
+        to_field_name='slug'
+    )
 
     class Meta:
         model = WirelessLANGroup

+ 31 - 13
netbox/wireless/tests/test_filtersets.py

@@ -17,21 +17,32 @@ class WirelessLANGroupTestCase(TestCase, ChangeLoggedFilterSetTests):
     @classmethod
     def setUpTestData(cls):
 
-        groups = (
+        parent_groups = (
             WirelessLANGroup(name='Wireless LAN Group 1', slug='wireless-lan-group-1', description='A'),
             WirelessLANGroup(name='Wireless LAN Group 2', slug='wireless-lan-group-2', description='B'),
             WirelessLANGroup(name='Wireless LAN Group 3', slug='wireless-lan-group-3', description='C'),
         )
+        for group in parent_groups:
+            group.save()
+
+        groups = (
+            WirelessLANGroup(name='Wireless LAN Group 1A', slug='wireless-lan-group-1a', parent=parent_groups[0], description='foobar1'),
+            WirelessLANGroup(name='Wireless LAN Group 1B', slug='wireless-lan-group-1b', parent=parent_groups[0], description='foobar2'),
+            WirelessLANGroup(name='Wireless LAN Group 2A', slug='wireless-lan-group-2a', parent=parent_groups[1]),
+            WirelessLANGroup(name='Wireless LAN Group 2B', slug='wireless-lan-group-2b', parent=parent_groups[1]),
+            WirelessLANGroup(name='Wireless LAN Group 3A', slug='wireless-lan-group-3a', parent=parent_groups[2]),
+            WirelessLANGroup(name='Wireless LAN Group 3B', slug='wireless-lan-group-3b', parent=parent_groups[2]),
+        )
         for group in groups:
             group.save()
 
         child_groups = (
-            WirelessLANGroup(name='Wireless LAN Group 1A', slug='wireless-lan-group-1a', parent=groups[0], description='foobar1'),
-            WirelessLANGroup(name='Wireless LAN Group 1B', slug='wireless-lan-group-1b', parent=groups[0], description='foobar2'),
-            WirelessLANGroup(name='Wireless LAN Group 2A', slug='wireless-lan-group-2a', parent=groups[1]),
-            WirelessLANGroup(name='Wireless LAN Group 2B', slug='wireless-lan-group-2b', parent=groups[1]),
-            WirelessLANGroup(name='Wireless LAN Group 3A', slug='wireless-lan-group-3a', parent=groups[2]),
-            WirelessLANGroup(name='Wireless LAN Group 3B', slug='wireless-lan-group-3b', parent=groups[2]),
+            WirelessLANGroup(name='Wireless LAN Group 1A1', slug='wireless-lan-group-1a1', parent=groups[0]),
+            WirelessLANGroup(name='Wireless LAN Group 1B1', slug='wireless-lan-group-1b1', parent=groups[1]),
+            WirelessLANGroup(name='Wireless LAN Group 2A1', slug='wireless-lan-group-2a1', parent=groups[2]),
+            WirelessLANGroup(name='Wireless LAN Group 2B1', slug='wireless-lan-group-2b1', parent=groups[3]),
+            WirelessLANGroup(name='Wireless LAN Group 3A1', slug='wireless-lan-group-3a1', parent=groups[4]),
+            WirelessLANGroup(name='Wireless LAN Group 3B1', slug='wireless-lan-group-3b1', parent=groups[5]),
         )
         for group in child_groups:
             group.save()
@@ -48,16 +59,23 @@ class WirelessLANGroupTestCase(TestCase, ChangeLoggedFilterSetTests):
         params = {'slug': ['wireless-lan-group-1', 'wireless-lan-group-2']}
         self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
 
+    def test_description(self):
+        params = {'description': ['foobar1', 'foobar2']}
+        self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
+
     def test_parent(self):
-        parent_groups = WirelessLANGroup.objects.filter(parent__isnull=True)[:2]
-        params = {'parent_id': [parent_groups[0].pk, parent_groups[1].pk]}
+        groups = WirelessLANGroup.objects.filter(parent__isnull=True)[:2]
+        params = {'parent_id': [groups[0].pk, groups[1].pk]}
         self.assertEqual(self.filterset(params, self.queryset).qs.count(), 4)
-        params = {'parent': [parent_groups[0].slug, parent_groups[1].slug]}
+        params = {'parent': [groups[0].slug, groups[1].slug]}
         self.assertEqual(self.filterset(params, self.queryset).qs.count(), 4)
 
-    def test_description(self):
-        params = {'description': ['foobar1', 'foobar2']}
-        self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
+    def test_ancestor(self):
+        groups = WirelessLANGroup.objects.filter(parent__isnull=True)[:2]
+        params = {'ancestor_id': [groups[0].pk, groups[1].pk]}
+        self.assertEqual(self.filterset(params, self.queryset).qs.count(), 8)
+        params = {'ancestor': [groups[0].slug, groups[1].slug]}
+        self.assertEqual(self.filterset(params, self.queryset).qs.count(), 8)
 
 
 class WirelessLANTestCase(TestCase, ChangeLoggedFilterSetTests):