Jeremy Stretch 1 dzień temu
rodzic
commit
55daf4c52f
2 zmienionych plików z 120 dodań i 10 usunięć
  1. 110 1
      netbox/dcim/tests/test_forms.py
  2. 10 9
      netbox/netbox/forms/model_forms.py

+ 110 - 1
netbox/dcim/tests/test_forms.py

@@ -10,7 +10,8 @@ from dcim.choices import (
 )
 from dcim.forms import *
 from dcim.models import *
-from ipam.models import VLAN
+from ipam.models import ASN, RIR, VLAN
+from utilities.forms.rendering import M2MAddRemoveFields
 from utilities.testing import create_test_device
 from virtualization.models import Cluster, ClusterGroup, ClusterType
 
@@ -417,3 +418,111 @@ class InterfaceTestCase(TestCase):
         self.assertNotIn('untagged_vlan', form.cleaned_data.keys())
         self.assertNotIn('tagged_vlans', form.cleaned_data.keys())
         self.assertNotIn('qinq_svlan', form.cleaned_data.keys())
+
+
+class SiteFormTestCase(TestCase):
+    """
+    Tests for M2MAddRemoveFields using Site ASN assignments as the test case.
+    Covers both simple mode (single multi-select field) and add/remove mode (dual fields).
+    """
+
+    @classmethod
+    def setUpTestData(cls):
+        cls.rir = RIR.objects.create(name='RIR 1', slug='rir-1')
+        # Create 110 ASNs: 100 to pre-assign (triggering add/remove mode) plus 10 extras
+        ASN.objects.bulk_create([ASN(asn=i, rir=cls.rir) for i in range(1, 111)])
+        cls.asns = list(ASN.objects.order_by('asn'))
+
+    def _site_data(self, **kwargs):
+        data = {'name': 'Test Site', 'slug': 'test-site', 'status': 'active'}
+        data.update(kwargs)
+        return data
+
+    def test_new_site_uses_simple_mode(self):
+        """A form for a new site uses the single 'asns' field (simple mode)."""
+        form = SiteForm(data=self._site_data())
+        self.assertIn('asns', form.fields)
+        self.assertNotIn('add_asns', form.fields)
+        self.assertNotIn('remove_asns', form.fields)
+
+    def test_existing_site_below_threshold_uses_simple_mode(self):
+        """A form for an existing site with fewer than THRESHOLD ASNs uses simple mode."""
+        site = Site.objects.create(name='Site 1', slug='site-1')
+        site.asns.set(self.asns[:5])
+        form = SiteForm(instance=site)
+        self.assertIn('asns', form.fields)
+        self.assertNotIn('add_asns', form.fields)
+        self.assertNotIn('remove_asns', form.fields)
+
+    def test_existing_site_at_threshold_uses_add_remove_mode(self):
+        """A form for an existing site with THRESHOLD or more ASNs uses add/remove mode."""
+        site = Site.objects.create(name='Site 2', slug='site-2')
+        site.asns.set(self.asns[:M2MAddRemoveFields.THRESHOLD])
+        form = SiteForm(instance=site)
+        self.assertNotIn('asns', form.fields)
+        self.assertIn('add_asns', form.fields)
+        self.assertIn('remove_asns', form.fields)
+
+    def test_simple_mode_assigns_asns_on_create(self):
+        """Saving a new site via simple mode assigns the selected ASNs."""
+        asn_pks = [asn.pk for asn in self.asns[:3]]
+        form = SiteForm(data=self._site_data(asns=asn_pks))
+        self.assertTrue(form.is_valid(), form.errors)
+        site = form.save()
+        self.assertEqual(set(site.asns.values_list('pk', flat=True)), set(asn_pks))
+
+    def test_simple_mode_replaces_asns_on_edit(self):
+        """Saving an existing site via simple mode replaces the current ASN assignments."""
+        site = Site.objects.create(name='Site 3', slug='site-3')
+        site.asns.set(self.asns[:3])
+        new_asn_pks = [asn.pk for asn in self.asns[3:6]]
+        form = SiteForm(
+            data=self._site_data(name='Site 3', slug='site-3', asns=new_asn_pks),
+            instance=site
+        )
+        self.assertTrue(form.is_valid(), form.errors)
+        site = form.save()
+        self.assertEqual(set(site.asns.values_list('pk', flat=True)), set(new_asn_pks))
+
+    def test_add_remove_mode_adds_asns(self):
+        """In add/remove mode, specifying 'add_asns' appends to current assignments."""
+        site = Site.objects.create(name='Site 4', slug='site-4')
+        site.asns.set(self.asns[:M2MAddRemoveFields.THRESHOLD])
+        new_asn_pks = [asn.pk for asn in self.asns[M2MAddRemoveFields.THRESHOLD:]]
+        form = SiteForm(
+            data=self._site_data(name='Site 4', slug='site-4', add_asns=new_asn_pks),
+            instance=site
+        )
+        self.assertTrue(form.is_valid(), form.errors)
+        site = form.save()
+        self.assertEqual(site.asns.count(), len(self.asns))
+
+    def test_add_remove_mode_removes_asns(self):
+        """In add/remove mode, specifying 'remove_asns' drops those assignments."""
+        site = Site.objects.create(name='Site 5', slug='site-5')
+        site.asns.set(self.asns[:M2MAddRemoveFields.THRESHOLD])
+        remove_pks = [asn.pk for asn in self.asns[:5]]
+        form = SiteForm(
+            data=self._site_data(name='Site 5', slug='site-5', remove_asns=remove_pks),
+            instance=site
+        )
+        self.assertTrue(form.is_valid(), form.errors)
+        site = form.save()
+        self.assertEqual(site.asns.count(), M2MAddRemoveFields.THRESHOLD - 5)
+        self.assertFalse(site.asns.filter(pk__in=remove_pks).exists())
+
+    def test_add_remove_mode_simultaneous_add_and_remove(self):
+        """In add/remove mode, add and remove operations are applied together."""
+        site = Site.objects.create(name='Site 6', slug='site-6')
+        site.asns.set(self.asns[:M2MAddRemoveFields.THRESHOLD])
+        add_pks = [asn.pk for asn in self.asns[M2MAddRemoveFields.THRESHOLD:M2MAddRemoveFields.THRESHOLD + 3]]
+        remove_pks = [asn.pk for asn in self.asns[:3]]
+        form = SiteForm(
+            data=self._site_data(name='Site 6', slug='site-6', add_asns=add_pks, remove_asns=remove_pks),
+            instance=site
+        )
+        self.assertTrue(form.is_valid(), form.errors)
+        site = form.save()
+        self.assertEqual(site.asns.count(), M2MAddRemoveFields.THRESHOLD)
+        self.assertTrue(site.asns.filter(pk__in=add_pks).count() == 3)
+        self.assertFalse(site.asns.filter(pk__in=remove_pks).exists())

+ 10 - 9
netbox/netbox/forms/model_forms.py

@@ -2,7 +2,6 @@ import json
 
 from django import forms
 from django.contrib.contenttypes.models import ContentType
-from django.db import models
 from django.db.models.fields.related import ManyToManyRel
 
 from extras.choices import *
@@ -77,15 +76,17 @@ class NetBoxModelForm(
         and add/remove (dual field) modes.
         """
         self.instance._m2m_values = {}
-        for field in self.instance._meta.get_fields():
-            # Determine the accessor name for this M2M relationship
-            if isinstance(field, models.ManyToManyField):
-                name = field.name
-            elif isinstance(field, ManyToManyRel):
-                name = field.get_accessor_name()
-            else:
-                continue
 
+        # Collect names to process: local M2M fields (includes TaggableManager from django-taggit)
+        # plus reverse M2M relations (ManyToManyRel).
+        names = [field.name for field in self.instance._meta.local_many_to_many]
+        names += [
+            field.get_accessor_name()
+            for field in self.instance._meta.get_fields()
+            if isinstance(field, ManyToManyRel)
+        ]
+
+        for name in names:
             if name in self.cleaned_data:
                 # Simple mode: single multi-select field
                 self.instance._m2m_values[name] = list(self.cleaned_data[name])