Просмотр исходного кода

Add missing filters for reverse many-to-many relationships

Jeremy Stretch 1 год назад
Родитель
Сommit
b36a70d236

+ 5 - 0
netbox/dcim/filtersets.py

@@ -1184,6 +1184,11 @@ class VirtualDeviceContextFilterSet(NetBoxModelFilterSet, TenancyFilterSet, Prim
         queryset=Device.objects.all(),
         label='Device model',
     )
+    interface_id = django_filters.ModelMultipleChoiceFilter(
+        field_name='interfaces',
+        queryset=Interface.objects.all(),
+        label='Interface (ID)',
+    )
     status = django_filters.MultipleChoiceFilter(
         choices=VirtualDeviceContextStatusChoices
     )

+ 24 - 15
netbox/dcim/tests/test_filtersets.py

@@ -5409,15 +5409,22 @@ class VirtualDeviceContextTestCase(TestCase, ChangeLoggedFilterSetTests):
         VirtualDeviceContext.objects.bulk_create(vdcs)
 
         interfaces = (
-            Interface(device=devices[0], name='Interface 1', type='virtual'),
-            Interface(device=devices[0], name='Interface 2', type='virtual'),
+            Interface(device=devices[0], name='Interface 1', type=InterfaceTypeChoices.TYPE_VIRTUAL),
+            Interface(device=devices[0], name='Interface 2', type=InterfaceTypeChoices.TYPE_VIRTUAL),
+            Interface(device=devices[1], name='Interface 3', type=InterfaceTypeChoices.TYPE_VIRTUAL),
+            Interface(device=devices[1], name='Interface 4', type=InterfaceTypeChoices.TYPE_VIRTUAL),
+            Interface(device=devices[2], name='Interface 5', type=InterfaceTypeChoices.TYPE_VIRTUAL),
+            Interface(device=devices[2], name='Interface 6', type=InterfaceTypeChoices.TYPE_VIRTUAL),
         )
         Interface.objects.bulk_create(interfaces)
-
         interfaces[0].vdcs.set([vdcs[0]])
         interfaces[1].vdcs.set([vdcs[1]])
+        interfaces[2].vdcs.set([vdcs[2]])
+        interfaces[3].vdcs.set([vdcs[3]])
+        interfaces[4].vdcs.set([vdcs[4]])
+        interfaces[5].vdcs.set([vdcs[5]])
 
-        addresses = (
+        ip_addresses = (
             IPAddress(assigned_object=interfaces[0], address='10.1.1.1/24'),
             IPAddress(assigned_object=interfaces[1], address='10.1.1.2/24'),
             IPAddress(assigned_object=None, address='10.1.1.3/24'),
@@ -5425,13 +5432,12 @@ class VirtualDeviceContextTestCase(TestCase, ChangeLoggedFilterSetTests):
             IPAddress(assigned_object=interfaces[1], address='2001:db8::2/64'),
             IPAddress(assigned_object=None, address='2001:db8::3/64'),
         )
-        IPAddress.objects.bulk_create(addresses)
-
-        vdcs[0].primary_ip4 = addresses[0]
-        vdcs[0].primary_ip6 = addresses[3]
+        IPAddress.objects.bulk_create(ip_addresses)
+        vdcs[0].primary_ip4 = ip_addresses[0]
+        vdcs[0].primary_ip6 = ip_addresses[3]
         vdcs[0].save()
-        vdcs[1].primary_ip4 = addresses[1]
-        vdcs[1].primary_ip6 = addresses[4]
+        vdcs[1].primary_ip4 = ip_addresses[1]
+        vdcs[1].primary_ip6 = ip_addresses[4]
         vdcs[1].save()
 
     def test_q(self):
@@ -5439,8 +5445,11 @@ class VirtualDeviceContextTestCase(TestCase, ChangeLoggedFilterSetTests):
         self.assertEqual(self.filterset(params, self.queryset).qs.count(), 1)
 
     def test_device(self):
-        params = {'device': ['Device 1', 'Device 2']}
+        devices = Device.objects.filter(name__in=['Device 1', 'Device 2'])
+        params = {'device': [devices[0].name, devices[1].name]}
         self.assertEqual(self.filterset(params, self.queryset).qs.count(), 6)
+        params = {'device_id': [devices[0].pk, devices[1].pk]}
+        self.assertEqual(self.filterset(params, self.queryset).qs.count(), 4)
 
     def test_status(self):
         params = {'status': ['active']}
@@ -5450,10 +5459,10 @@ class VirtualDeviceContextTestCase(TestCase, ChangeLoggedFilterSetTests):
         params = {'description': ['foobar1', 'foobar2']}
         self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
 
-    def test_device_id(self):
-        devices = Device.objects.filter(name__in=['Device 1', 'Device 2'])
-        params = {'device_id': [devices[0].pk, devices[1].pk]}
-        self.assertEqual(self.filterset(params, self.queryset).qs.count(), 4)
+    def test_interface(self):
+        interfaces = Interface.objects.filter(name__in=['Interface 1', 'Interface 3'])
+        params = {'interface_id': [interfaces[0].pk, interfaces[1].pk]}
+        self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
 
     def test_has_primary_ip(self):
         params = {'has_primary_ip': True}

+ 87 - 1
netbox/extras/tests/test_filtersets.py

@@ -1128,7 +1128,93 @@ class ConfigTemplateTestCase(TestCase, ChangeLoggedFilterSetTests):
 class TagTestCase(TestCase, ChangeLoggedFilterSetTests):
     queryset = Tag.objects.all()
     filterset = TagFilterSet
-    ignore_fields = ('object_types',)
+    ignore_fields = (
+        'object_types',
+
+        # Reverse relationships (to tagged models) we can ignore
+        'aggregate',
+        'asn',
+        'asnrange',
+        'cable',
+        'circuit',
+        'circuittermination',
+        'circuittype',
+        'cluster',
+        'clustergroup',
+        'clustertype',
+        'configtemplate',
+        'consoleport',
+        'consoleserverport',
+        'contact',
+        'contactassignment',
+        'contactgroup',
+        'contactrole',
+        'datasource',
+        'device',
+        'devicebay',
+        'devicerole',
+        'devicetype',
+        'dummymodel',  # From dummy_plugin
+        'eventrule',
+        'fhrpgroup',
+        'frontport',
+        'ikepolicy',
+        'ikeproposal',
+        'interface',
+        'inventoryitem',
+        'inventoryitemrole',
+        'ipaddress',
+        'iprange',
+        'ipsecpolicy',
+        'ipsecprofile',
+        'ipsecproposal',
+        'journalentry',
+        'l2vpn',
+        'l2vpntermination',
+        'location',
+        'manufacturer',
+        'module',
+        'modulebay',
+        'moduletype',
+        'platform',
+        'powerfeed',
+        'poweroutlet',
+        'powerpanel',
+        'powerport',
+        'prefix',
+        'provider',
+        'provideraccount',
+        'providernetwork',
+        'rack',
+        'rackreservation',
+        'rackrole',
+        'rearport',
+        'region',
+        'rir',
+        'role',
+        'routetarget',
+        'service',
+        'servicetemplate',
+        'site',
+        'sitegroup',
+        'tenant',
+        'tenantgroup',
+        'tunnel',
+        'tunnelgroup',
+        'tunneltermination',
+        'virtualchassis',
+        'virtualdevicecontext',
+        'virtualdisk',
+        'virtualmachine',
+        'vlan',
+        'vlangroup',
+        'vminterface',
+        'vrf',
+        'webhook',
+        'wirelesslan',
+        'wirelesslangroup',
+        'wirelesslink',
+    )
 
     @classmethod
     def setUpTestData(cls):

+ 39 - 0
netbox/ipam/filtersets.py

@@ -8,6 +8,7 @@ from drf_spectacular.types import OpenApiTypes
 from drf_spectacular.utils import extend_schema_field
 from netaddr.core import AddrFormatError
 
+from circuits.models import Provider
 from dcim.models import Device, Interface, Region, Site, SiteGroup
 from netbox.filtersets import ChangeLoggedModelFilterSet, OrganizationalModelFilterSet, NetBoxModelFilterSet
 from tenancy.filtersets import TenancyFilterSet
@@ -101,6 +102,28 @@ class RouteTargetFilterSet(NetBoxModelFilterSet, TenancyFilterSet):
         to_field_name='rd',
         label=_('Export VRF (RD)'),
     )
+    importing_l2vpn_id = django_filters.ModelMultipleChoiceFilter(
+        field_name='importing_l2vpns',
+        queryset=L2VPN.objects.all(),
+        label=_('Importing L2VPN'),
+    )
+    importing_l2vpn = django_filters.ModelMultipleChoiceFilter(
+        field_name='importing_l2vpns__identifier',
+        queryset=L2VPN.objects.all(),
+        to_field_name='identifier',
+        label=_('Importing L2VPN (identifier)'),
+    )
+    exporting_l2vpn_id = django_filters.ModelMultipleChoiceFilter(
+        field_name='exporting_l2vpns',
+        queryset=L2VPN.objects.all(),
+        label=_('Exporting L2VPN'),
+    )
+    exporting_l2vpn = django_filters.ModelMultipleChoiceFilter(
+        field_name='exporting_l2vpns__identifier',
+        queryset=L2VPN.objects.all(),
+        to_field_name='identifier',
+        label=_('Exporting L2VPN (identifier)'),
+    )
 
     def search(self, queryset, name, value):
         if not value.strip():
@@ -214,6 +237,17 @@ class ASNFilterSet(OrganizationalModelFilterSet, TenancyFilterSet):
         to_field_name='slug',
         label=_('Site (slug)'),
     )
+    provider_id = django_filters.ModelMultipleChoiceFilter(
+        field_name='providers',
+        queryset=Provider.objects.all(),
+        label=_('Provider (ID)'),
+    )
+    provider = django_filters.ModelMultipleChoiceFilter(
+        field_name='providers__slug',
+        queryset=Provider.objects.all(),
+        to_field_name='slug',
+        label=_('Provider (slug)'),
+    )
 
     class Meta:
         model = ASN
@@ -628,6 +662,11 @@ class IPAddressFilterSet(NetBoxModelFilterSet, TenancyFilterSet):
     role = django_filters.MultipleChoiceFilter(
         choices=IPAddressRoleChoices
     )
+    service_id = django_filters.ModelMultipleChoiceFilter(
+        field_name='services',
+        queryset=Service.objects.all(),
+        label=_('Service (ID)'),
+    )
 
     class Meta:
         model = IPAddress

+ 79 - 8
netbox/ipam/tests/test_filtersets.py

@@ -2,6 +2,7 @@ from django.contrib.contenttypes.models import ContentType
 from django.test import TestCase
 from netaddr import IPNetwork
 
+from circuits.models import Provider
 from dcim.choices import InterfaceTypeChoices
 from dcim.models import Device, DeviceRole, DeviceType, Interface, Location, Manufacturer, Rack, Region, Site, SiteGroup
 from ipam.choices import *
@@ -10,6 +11,8 @@ from ipam.models import *
 from tenancy.models import Tenant, TenantGroup
 from utilities.testing import ChangeLoggedFilterSetTests, create_test_device, create_test_virtualmachine
 from virtualization.models import Cluster, ClusterGroup, ClusterType, VirtualMachine, VMInterface
+from vpn.choices import L2VPNTypeChoices
+from vpn.models import L2VPN
 
 
 class ASNRangeTestCase(TestCase, ChangeLoggedFilterSetTests):
@@ -110,13 +113,6 @@ class ASNTestCase(TestCase, ChangeLoggedFilterSetTests):
         ]
         RIR.objects.bulk_create(rirs)
 
-        sites = [
-            Site(name='Site 1', slug='site-1'),
-            Site(name='Site 2', slug='site-2'),
-            Site(name='Site 3', slug='site-3')
-        ]
-        Site.objects.bulk_create(sites)
-
         tenants = [
             Tenant(name='Tenant 1', slug='tenant-1'),
             Tenant(name='Tenant 2', slug='tenant-2'),
@@ -136,6 +132,12 @@ class ASNTestCase(TestCase, ChangeLoggedFilterSetTests):
         )
         ASN.objects.bulk_create(asns)
 
+        sites = [
+            Site(name='Site 1', slug='site-1'),
+            Site(name='Site 2', slug='site-2'),
+            Site(name='Site 3', slug='site-3')
+        ]
+        Site.objects.bulk_create(sites)
         asns[0].sites.set([sites[0]])
         asns[1].sites.set([sites[1]])
         asns[2].sites.set([sites[2]])
@@ -143,6 +145,16 @@ class ASNTestCase(TestCase, ChangeLoggedFilterSetTests):
         asns[4].sites.set([sites[1]])
         asns[5].sites.set([sites[2]])
 
+        providers = (
+            Provider(name='Provider 1', slug='provider-1'),
+            Provider(name='Provider 2', slug='provider-2'),
+            Provider(name='Provider 3', slug='provider-3'),
+        )
+        Provider.objects.bulk_create(providers)
+        providers[0].asns.add(asns[0])
+        providers[1].asns.add(asns[1])
+        providers[2].asns.add(asns[2])
+
     def test_q(self):
         params = {'q': 'foobar1'}
         self.assertEqual(self.filterset(params, self.queryset).qs.count(), 1)
@@ -176,6 +188,11 @@ class ASNTestCase(TestCase, ChangeLoggedFilterSetTests):
         params = {'description': ['foobar1', 'foobar2']}
         self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
 
+    def test_provider(self):
+        providers = Provider.objects.all()[:2]
+        params = {'provider_id': [providers[0].pk, providers[1].pk]}
+        self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
+
 
 class VRFTestCase(TestCase, ChangeLoggedFilterSetTests):
     queryset = VRF.objects.all()
@@ -188,7 +205,7 @@ class VRFTestCase(TestCase, ChangeLoggedFilterSetTests):
             return 'import_target'
         if field.name == 'export_targets':
             return 'export_target'
-        return super().get_m2m_filter_name(field)
+        return ChangeLoggedFilterSetTests.get_m2m_filter_name(field)
 
     @classmethod
     def setUpTestData(cls):
@@ -286,6 +303,19 @@ class RouteTargetTestCase(TestCase, ChangeLoggedFilterSetTests):
     queryset = RouteTarget.objects.all()
     filterset = RouteTargetFilterSet
 
+    @staticmethod
+    def get_m2m_filter_name(field):
+        # Override filter names for import & export VRFs and L2VPNs
+        if field.name == 'importing_vrfs':
+            return 'importing_vrf'
+        if field.name == 'exporting_vrfs':
+            return 'exporting_vrf'
+        if field.name == 'importing_l2vpns':
+            return 'importing_l2vpn'
+        if field.name == 'exporting_l2vpns':
+            return 'exporting_l2vpn'
+        return ChangeLoggedFilterSetTests.get_m2m_filter_name(field)
+
     @classmethod
     def setUpTestData(cls):
 
@@ -331,6 +361,17 @@ class RouteTargetTestCase(TestCase, ChangeLoggedFilterSetTests):
         vrfs[1].import_targets.add(route_targets[4], route_targets[5])
         vrfs[1].export_targets.add(route_targets[6], route_targets[7])
 
+        l2vpns = (
+            L2VPN(name='L2VPN 1', slug='l2vpn-1', type=L2VPNTypeChoices.TYPE_VXLAN, identifier=100),
+            L2VPN(name='L2VPN 2', slug='l2vpn-2', type=L2VPNTypeChoices.TYPE_VXLAN, identifier=200),
+            L2VPN(name='L2VPN 3', slug='l2vpn-3', type=L2VPNTypeChoices.TYPE_VXLAN, identifier=300),
+        )
+        L2VPN.objects.bulk_create(l2vpns)
+        l2vpns[0].import_targets.add(route_targets[0], route_targets[1])
+        l2vpns[0].export_targets.add(route_targets[2], route_targets[3])
+        l2vpns[1].import_targets.add(route_targets[4], route_targets[5])
+        l2vpns[1].export_targets.add(route_targets[6], route_targets[7])
+
     def test_q(self):
         params = {'q': 'foobar1'}
         self.assertEqual(self.filterset(params, self.queryset).qs.count(), 1)
@@ -353,6 +394,20 @@ class RouteTargetTestCase(TestCase, ChangeLoggedFilterSetTests):
         params = {'exporting_vrf': [vrfs[0].rd, vrfs[1].rd]}
         self.assertEqual(self.filterset(params, self.queryset).qs.count(), 4)
 
+    def test_importing_l2vpn(self):
+        l2vpns = L2VPN.objects.all()[:2]
+        params = {'importing_l2vpn_id': [l2vpns[0].pk, l2vpns[1].pk]}
+        self.assertEqual(self.filterset(params, self.queryset).qs.count(), 4)
+        params = {'importing_l2vpn': [l2vpns[0].identifier, l2vpns[1].identifier]}
+        self.assertEqual(self.filterset(params, self.queryset).qs.count(), 4)
+
+    def test_exporting_l2vpn(self):
+        l2vpns = L2VPN.objects.all()[:2]
+        params = {'exporting_l2vpn_id': [l2vpns[0].pk, l2vpns[1].pk]}
+        self.assertEqual(self.filterset(params, self.queryset).qs.count(), 4)
+        params = {'exporting_l2vpn': [l2vpns[0].identifier, l2vpns[1].identifier]}
+        self.assertEqual(self.filterset(params, self.queryset).qs.count(), 4)
+
     def test_tenant(self):
         tenants = Tenant.objects.all()[:2]
         params = {'tenant_id': [tenants[0].pk, tenants[1].pk]}
@@ -1102,6 +1157,16 @@ class IPAddressTestCase(TestCase, ChangeLoggedFilterSetTests):
         )
         IPAddress.objects.bulk_create(ipaddresses)
 
+        services = (
+            Service(name='Service 1', protocol=ServiceProtocolChoices.PROTOCOL_TCP, ports=[1]),
+            Service(name='Service 2', protocol=ServiceProtocolChoices.PROTOCOL_TCP, ports=[1]),
+            Service(name='Service 3', protocol=ServiceProtocolChoices.PROTOCOL_TCP, ports=[1]),
+        )
+        Service.objects.bulk_create(services)
+        services[0].ipaddresses.add(ipaddresses[0])
+        services[1].ipaddresses.add(ipaddresses[1])
+        services[2].ipaddresses.add(ipaddresses[2])
+
     def test_q(self):
         params = {'q': 'foobar1'}
         self.assertEqual(self.filterset(params, self.queryset).qs.count(), 1)
@@ -1241,6 +1306,11 @@ class IPAddressTestCase(TestCase, ChangeLoggedFilterSetTests):
         params = {'tenant_group': [tenant_groups[0].slug, tenant_groups[1].slug]}
         self.assertEqual(self.filterset(params, self.queryset).qs.count(), 4)
 
+    def test_service(self):
+        services = Service.objects.all()[:2]
+        params = {'service_id': [services[0].pk, services[1].pk]}
+        self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
+
 
 class FHRPGroupTestCase(TestCase, ChangeLoggedFilterSetTests):
     queryset = FHRPGroup.objects.all()
@@ -1485,6 +1555,7 @@ class VLANGroupTestCase(TestCase, ChangeLoggedFilterSetTests):
 class VLANTestCase(TestCase, ChangeLoggedFilterSetTests):
     queryset = VLAN.objects.all()
     filterset = VLANFilterSet
+    ignore_fields = ('interfaces_as_tagged', 'vminterfaces_as_tagged')
 
     @classmethod
     def setUpTestData(cls):

+ 15 - 0
netbox/users/filtersets.py

@@ -20,6 +20,16 @@ class GroupFilterSet(BaseFilterSet):
         method='search',
         label=_('Search'),
     )
+    user_id = django_filters.ModelMultipleChoiceFilter(
+        field_name='user',
+        queryset=get_user_model().objects.all(),
+        label=_('User (ID)'),
+    )
+    permission_id = django_filters.ModelMultipleChoiceFilter(
+        field_name='object_permissions',
+        queryset=ObjectPermission.objects.all(),
+        label=_('Permission (ID)'),
+    )
 
     class Meta:
         model = Group
@@ -47,6 +57,11 @@ class UserFilterSet(BaseFilterSet):
         to_field_name='name',
         label=_('Group (name)'),
     )
+    permission_id = django_filters.ModelMultipleChoiceFilter(
+        field_name='object_permissions',
+        queryset=ObjectPermission.objects.all(),
+        label=_('Permission (ID)'),
+    )
 
     class Meta:
         model = get_user_model()

+ 45 - 0
netbox/users/tests/test_filtersets.py

@@ -67,6 +67,16 @@ class UserTestCase(TestCase, BaseFilterSetTests):
         users[1].groups.set([groups[1]])
         users[2].groups.set([groups[2]])
 
+        object_permissions = (
+            ObjectPermission(name='Permission 1', actions=['add']),
+            ObjectPermission(name='Permission 2', actions=['change']),
+            ObjectPermission(name='Permission 3', actions=['delete']),
+        )
+        ObjectPermission.objects.bulk_create(object_permissions)
+        object_permissions[0].users.add(users[0])
+        object_permissions[1].users.add(users[1])
+        object_permissions[2].users.add(users[2])
+
     def test_q(self):
         params = {'q': 'user1'}
         self.assertEqual(self.filterset(params, self.queryset).qs.count(), 1)
@@ -106,6 +116,11 @@ class UserTestCase(TestCase, BaseFilterSetTests):
         params = {'group': [groups[0].name, groups[1].name]}
         self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
 
+    def test_permission(self):
+        object_permissions = ObjectPermission.objects.all()[:2]
+        params = {'permission_id': [object_permissions[0].pk, object_permissions[1].pk]}
+        self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
+
 
 class GroupTestCase(TestCase, BaseFilterSetTests):
     queryset = Group.objects.all()
@@ -122,6 +137,26 @@ class GroupTestCase(TestCase, BaseFilterSetTests):
         )
         Group.objects.bulk_create(groups)
 
+        users = (
+            User(username='User 1'),
+            User(username='User 2'),
+            User(username='User 3'),
+        )
+        User.objects.bulk_create(users)
+        users[0].groups.set([groups[0]])
+        users[1].groups.set([groups[1]])
+        users[2].groups.set([groups[2]])
+
+        object_permissions = (
+            ObjectPermission(name='Permission 1', actions=['add']),
+            ObjectPermission(name='Permission 2', actions=['change']),
+            ObjectPermission(name='Permission 3', actions=['delete']),
+        )
+        ObjectPermission.objects.bulk_create(object_permissions)
+        object_permissions[0].groups.add(groups[0])
+        object_permissions[1].groups.add(groups[1])
+        object_permissions[2].groups.add(groups[2])
+
     def test_q(self):
         params = {'q': 'group 1'}
         self.assertEqual(self.filterset(params, self.queryset).qs.count(), 1)
@@ -130,6 +165,16 @@ class GroupTestCase(TestCase, BaseFilterSetTests):
         params = {'name': ['Group 1', 'Group 2']}
         self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
 
+    def test_user(self):
+        users = User.objects.all()[:2]
+        params = {'user_id': [users[0].pk, users[1].pk]}
+        self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
+
+    def test_permission(self):
+        object_permissions = ObjectPermission.objects.all()[:2]
+        params = {'permission_id': [object_permissions[0].pk, object_permissions[1].pk]}
+        self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
+
 
 class ObjectPermissionTestCase(TestCase, BaseFilterSetTests):
     queryset = ObjectPermission.objects.all()

+ 1 - 5
netbox/utilities/testing/filtersets.py

@@ -109,7 +109,7 @@ class BaseFilterSetTests:
                     f'No filter defined for {filter_name} ({model_field.name})!'
                 )
 
-            elif type(model_field) is ManyToManyField:
+            elif type(model_field) in (ManyToManyField, ManyToManyRel):
                 filter_name = self.get_m2m_filter_name(model_field)
                 filter_name = f'{filter_name}_id'
                 self.assertIn(
@@ -118,10 +118,6 @@ class BaseFilterSetTests:
                     f'No filter defined for {filter_name} ({model_field.name})!'
                 )
 
-            # TODO: Many-to-many relationships
-            elif type(model_field) is ManyToManyRel:
-                continue
-
             # TODO: Generic relationships
             elif type(model_field) in (GenericForeignKey, GenericRelation):
                 continue

+ 22 - 0
netbox/vpn/filtersets.py

@@ -124,6 +124,17 @@ class TunnelTerminationFilterSet(NetBoxModelFilterSet):
 
 
 class IKEProposalFilterSet(NetBoxModelFilterSet):
+    ike_policy_id = django_filters.ModelMultipleChoiceFilter(
+        field_name='ike_policies',
+        queryset=IKEPolicy.objects.all(),
+        label=_('IKE policy (ID)'),
+    )
+    ike_policy = django_filters.ModelMultipleChoiceFilter(
+        field_name='ike_policies__name',
+        queryset=IKEPolicy.objects.all(),
+        to_field_name='name',
+        label=_('IKE policy (name)'),
+    )
     authentication_method = django_filters.MultipleChoiceFilter(
         choices=AuthenticationMethodChoices
     )
@@ -184,6 +195,17 @@ class IKEPolicyFilterSet(NetBoxModelFilterSet):
 
 
 class IPSecProposalFilterSet(NetBoxModelFilterSet):
+    ipsec_policy_id = django_filters.ModelMultipleChoiceFilter(
+        field_name='ipsec_policies',
+        queryset=IPSecPolicy.objects.all(),
+        label=_('IPSec policy (ID)'),
+    )
+    ipsec_policy = django_filters.ModelMultipleChoiceFilter(
+        field_name='ipsec_policies__name',
+        queryset=IPSecPolicy.objects.all(),
+        to_field_name='name',
+        label=_('IPSec policy (name)'),
+    )
     encryption_algorithm = django_filters.MultipleChoiceFilter(
         choices=EncryptionAlgorithmChoices
     )

+ 35 - 1
netbox/vpn/tests/test_filtersets.py

@@ -330,6 +330,16 @@ class IKEProposalTestCase(TestCase, ChangeLoggedFilterSetTests):
         )
         IKEProposal.objects.bulk_create(ike_proposals)
 
+        ike_policies = (
+            IKEPolicy(name='IKE Policy 1'),
+            IKEPolicy(name='IKE Policy 2'),
+            IKEPolicy(name='IKE Policy 3'),
+        )
+        IKEPolicy.objects.bulk_create(ike_policies)
+        ike_policies[0].proposals.add(ike_proposals[0])
+        ike_policies[1].proposals.add(ike_proposals[1])
+        ike_policies[2].proposals.add(ike_proposals[2])
+
     def test_q(self):
         params = {'q': 'foobar1'}
         self.assertEqual(self.filterset(params, self.queryset).qs.count(), 1)
@@ -342,6 +352,13 @@ class IKEProposalTestCase(TestCase, ChangeLoggedFilterSetTests):
         params = {'description': ['foobar1', 'foobar2']}
         self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
 
+    def test_ike_policy(self):
+        ike_policies = IKEPolicy.objects.all()[:2]
+        params = {'ike_policy_id': [ike_policies[0].pk, ike_policies[1].pk]}
+        self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
+        params = {'ike_policy': [ike_policies[0].name, ike_policies[1].name]}
+        self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
+
     def test_authentication_method(self):
         params = {'authentication_method': [
             AuthenticationMethodChoices.PRESHARED_KEYS, AuthenticationMethodChoices.CERTIFICATES
@@ -487,6 +504,16 @@ class IPSecProposalTestCase(TestCase, ChangeLoggedFilterSetTests):
         )
         IPSecProposal.objects.bulk_create(ipsec_proposals)
 
+        ipsec_policies = (
+            IPSecPolicy(name='IPSec Policy 1'),
+            IPSecPolicy(name='IPSec Policy 2'),
+            IPSecPolicy(name='IPSec Policy 3'),
+        )
+        IPSecPolicy.objects.bulk_create(ipsec_policies)
+        ipsec_policies[0].proposals.add(ipsec_proposals[0])
+        ipsec_policies[1].proposals.add(ipsec_proposals[1])
+        ipsec_policies[2].proposals.add(ipsec_proposals[2])
+
     def test_q(self):
         params = {'q': 'foobar1'}
         self.assertEqual(self.filterset(params, self.queryset).qs.count(), 1)
@@ -499,6 +526,13 @@ class IPSecProposalTestCase(TestCase, ChangeLoggedFilterSetTests):
         params = {'description': ['foobar1', 'foobar2']}
         self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
 
+    def test_ipsec_policy(self):
+        ipsec_policies = IPSecPolicy.objects.all()[:2]
+        params = {'ipsec_policy_id': [ipsec_policies[0].pk, ipsec_policies[1].pk]}
+        self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
+        params = {'ipsec_policy': [ipsec_policies[0].name, ipsec_policies[1].name]}
+        self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
+
     def test_encryption_algorithm(self):
         params = {'encryption_algorithm': [
             EncryptionAlgorithmChoices.ENCRYPTION_AES128_CBC, EncryptionAlgorithmChoices.ENCRYPTION_AES192_CBC
@@ -716,7 +750,7 @@ class L2VPNTestCase(TestCase, ChangeLoggedFilterSetTests):
             return 'import_target'
         if field.name == 'export_targets':
             return 'export_target'
-        return super().get_m2m_filter_name(field)
+        return ChangeLoggedFilterSetTests.get_m2m_filter_name(field)
 
     @classmethod
     def setUpTestData(cls):

+ 5 - 0
netbox/wireless/filtersets.py

@@ -2,6 +2,7 @@ import django_filters
 from django.db.models import Q
 
 from dcim.choices import LinkStatusChoices
+from dcim.models import Interface
 from ipam.models import VLAN
 from netbox.filtersets import OrganizationalModelFilterSet, NetBoxModelFilterSet
 from tenancy.filtersets import TenancyFilterSet
@@ -60,6 +61,10 @@ class WirelessLANFilterSet(NetBoxModelFilterSet, TenancyFilterSet):
     vlan_id = django_filters.ModelMultipleChoiceFilter(
         queryset=VLAN.objects.all()
     )
+    interface_id = django_filters.ModelMultipleChoiceFilter(
+        queryset=Interface.objects.all(),
+        field_name='interfaces'
+    )
     auth_type = django_filters.MultipleChoiceFilter(
         choices=WirelessAuthTypeChoices
     )

+ 16 - 0
netbox/wireless/tests/test_filtersets.py

@@ -153,6 +153,17 @@ class WirelessLANTestCase(TestCase, ChangeLoggedFilterSetTests):
         )
         WirelessLAN.objects.bulk_create(wireless_lans)
 
+        device = create_test_device('Device 1')
+        interfaces = (
+            Interface(device=device, name='Interface 1', type=InterfaceTypeChoices.TYPE_80211N),
+            Interface(device=device, name='Interface 2', type=InterfaceTypeChoices.TYPE_80211N),
+            Interface(device=device, name='Interface 3', type=InterfaceTypeChoices.TYPE_80211N),
+        )
+        Interface.objects.bulk_create(interfaces)
+        interfaces[0].wireless_lans.add(wireless_lans[0])
+        interfaces[1].wireless_lans.add(wireless_lans[1])
+        interfaces[2].wireless_lans.add(wireless_lans[2])
+
     def test_q(self):
         params = {'q': 'foobar1'}
         self.assertEqual(self.filterset(params, self.queryset).qs.count(), 1)
@@ -200,6 +211,11 @@ class WirelessLANTestCase(TestCase, ChangeLoggedFilterSetTests):
         params = {'tenant': [tenants[0].slug, tenants[1].slug]}
         self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
 
+    def test_interface(self):
+        interfaces = Interface.objects.all()[:2]
+        params = {'interface_id': [interfaces[0].pk, interfaces[1].pk]}
+        self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
+
 
 class WirelessLinkTestCase(TestCase, ChangeLoggedFilterSetTests):
     queryset = WirelessLink.objects.all()