Просмотр исходного кода

Fixes #14322: Populate default custom field values when instantiating templated device components

Jeremy Stretch 2 лет назад
Родитель
Сommit
32264ac3e3
3 измененных файлов с 88 добавлено и 35 удалено
  1. 12 7
      netbox/dcim/models/devices.py
  2. 67 28
      netbox/dcim/tests/test_models.py
  3. 9 0
      netbox/extras/models/customfields.py

+ 12 - 7
netbox/dcim/models/devices.py

@@ -16,7 +16,7 @@ from django.utils.translation import gettext_lazy as _
 
 
 from dcim.choices import *
 from dcim.choices import *
 from dcim.constants import *
 from dcim.constants import *
-from extras.models import ConfigContextModel
+from extras.models import ConfigContextModel, CustomField
 from extras.querysets import ConfigContextModelQuerySet
 from extras.querysets import ConfigContextModelQuerySet
 from netbox.config import ConfigItem
 from netbox.config import ConfigItem
 from netbox.models import OrganizationalModel, PrimaryModel
 from netbox.models import OrganizationalModel, PrimaryModel
@@ -985,11 +985,17 @@ class Device(
             bulk_create: If True, bulk_create() will be called to create all components in a single query
             bulk_create: If True, bulk_create() will be called to create all components in a single query
                          (default). Otherwise, save() will be called on each instance individually.
                          (default). Otherwise, save() will be called on each instance individually.
         """
         """
+        components = [obj.instantiate(device=self) for obj in queryset]
+        if not components:
+            return
+
+        # Set default values for any applicable custom fields
+        model = queryset.model.component_model
+        if cf_defaults := CustomField.objects.get_defaults_for_model(model):
+            for component in components:
+                component.custom_field_data = cf_defaults
+
         if bulk_create:
         if bulk_create:
-            components = [obj.instantiate(device=self) for obj in queryset]
-            if not components:
-                return
-            model = components[0]._meta.model
             model.objects.bulk_create(components)
             model.objects.bulk_create(components)
             # Manually send the post_save signal for each of the newly created components
             # Manually send the post_save signal for each of the newly created components
             for component in components:
             for component in components:
@@ -1002,8 +1008,7 @@ class Device(
                     update_fields=None
                     update_fields=None
                 )
                 )
         else:
         else:
-            for obj in queryset:
-                component = obj.instantiate(device=self)
+            for component in components:
                 component.save()
                 component.save()
 
 
     def save(self, *args, **kwargs):
     def save(self, *args, **kwargs):

+ 67 - 28
netbox/dcim/tests/test_models.py

@@ -1,9 +1,11 @@
+from django.contrib.contenttypes.models import ContentType
 from django.core.exceptions import ValidationError
 from django.core.exceptions import ValidationError
 from django.test import TestCase
 from django.test import TestCase
 
 
 from circuits.models import *
 from circuits.models import *
 from dcim.choices import *
 from dcim.choices import *
 from dcim.models import *
 from dcim.models import *
+from extras.models import CustomField
 from tenancy.models import Tenant
 from tenancy.models import Tenant
 from utilities.utils import drange
 from utilities.utils import drange
 
 
@@ -255,6 +257,23 @@ class DeviceTestCase(TestCase):
         )
         )
         DeviceRole.objects.bulk_create(roles)
         DeviceRole.objects.bulk_create(roles)
 
 
+        # Create a CustomField with a default value & assign it to all component models
+        cf1 = CustomField.objects.create(name='cf1', default='foo')
+        cf1.content_types.set(
+            ContentType.objects.filter(app_label='dcim', model__in=[
+                'consoleport',
+                'consoleserverport',
+                'powerport',
+                'poweroutlet',
+                'interface',
+                'rearport',
+                'frontport',
+                'modulebay',
+                'devicebay',
+                'inventoryitem',
+            ])
+        )
+
         # Create DeviceType components
         # Create DeviceType components
         ConsolePortTemplate(
         ConsolePortTemplate(
             device_type=device_type,
             device_type=device_type,
@@ -266,18 +285,18 @@ class DeviceTestCase(TestCase):
             name='Console Server Port 1'
             name='Console Server Port 1'
         ).save()
         ).save()
 
 
-        ppt = PowerPortTemplate(
+        powerport = PowerPortTemplate(
             device_type=device_type,
             device_type=device_type,
             name='Power Port 1',
             name='Power Port 1',
             maximum_draw=1000,
             maximum_draw=1000,
             allocated_draw=500
             allocated_draw=500
         )
         )
-        ppt.save()
+        powerport.save()
 
 
         PowerOutletTemplate(
         PowerOutletTemplate(
             device_type=device_type,
             device_type=device_type,
             name='Power Outlet 1',
             name='Power Outlet 1',
-            power_port=ppt,
+            power_port=powerport,
             feed_leg=PowerOutletFeedLegChoices.FEED_LEG_A
             feed_leg=PowerOutletFeedLegChoices.FEED_LEG_A
         ).save()
         ).save()
 
 
@@ -288,19 +307,19 @@ class DeviceTestCase(TestCase):
             mgmt_only=True
             mgmt_only=True
         ).save()
         ).save()
 
 
-        rpt = RearPortTemplate(
+        rearport = RearPortTemplate(
             device_type=device_type,
             device_type=device_type,
             name='Rear Port 1',
             name='Rear Port 1',
             type=PortTypeChoices.TYPE_8P8C,
             type=PortTypeChoices.TYPE_8P8C,
             positions=8
             positions=8
         )
         )
-        rpt.save()
+        rearport.save()
 
 
         FrontPortTemplate(
         FrontPortTemplate(
             device_type=device_type,
             device_type=device_type,
             name='Front Port 1',
             name='Front Port 1',
             type=PortTypeChoices.TYPE_8P8C,
             type=PortTypeChoices.TYPE_8P8C,
-            rear_port=rpt,
+            rear_port=rearport,
             rear_port_position=2
             rear_port_position=2
         ).save()
         ).save()
 
 
@@ -314,73 +333,93 @@ class DeviceTestCase(TestCase):
             name='Device Bay 1'
             name='Device Bay 1'
         ).save()
         ).save()
 
 
+        InventoryItemTemplate(
+            device_type=device_type,
+            name='Inventory Item 1'
+        ).save()
+
     def test_device_creation(self):
     def test_device_creation(self):
         """
         """
         Ensure that all Device components are copied automatically from the DeviceType.
         Ensure that all Device components are copied automatically from the DeviceType.
         """
         """
-        d = Device(
+        device = Device(
             site=Site.objects.first(),
             site=Site.objects.first(),
             device_type=DeviceType.objects.first(),
             device_type=DeviceType.objects.first(),
             role=DeviceRole.objects.first(),
             role=DeviceRole.objects.first(),
             name='Test Device 1'
             name='Test Device 1'
         )
         )
-        d.save()
+        device.save()
 
 
-        ConsolePort.objects.get(
-            device=d,
+        consoleport = ConsolePort.objects.get(
+            device=device,
             name='Console Port 1'
             name='Console Port 1'
         )
         )
+        self.assertEqual(consoleport.cf['cf1'], 'foo')
 
 
-        ConsoleServerPort.objects.get(
-            device=d,
+        consoleserverport = ConsoleServerPort.objects.get(
+            device=device,
             name='Console Server Port 1'
             name='Console Server Port 1'
         )
         )
+        self.assertEqual(consoleserverport.cf['cf1'], 'foo')
 
 
-        pp = PowerPort.objects.get(
-            device=d,
+        powerport = PowerPort.objects.get(
+            device=device,
             name='Power Port 1',
             name='Power Port 1',
             maximum_draw=1000,
             maximum_draw=1000,
             allocated_draw=500
             allocated_draw=500
         )
         )
+        self.assertEqual(powerport.cf['cf1'], 'foo')
 
 
-        PowerOutlet.objects.get(
-            device=d,
+        poweroutlet = PowerOutlet.objects.get(
+            device=device,
             name='Power Outlet 1',
             name='Power Outlet 1',
-            power_port=pp,
+            power_port=powerport,
             feed_leg=PowerOutletFeedLegChoices.FEED_LEG_A
             feed_leg=PowerOutletFeedLegChoices.FEED_LEG_A
         )
         )
+        self.assertEqual(poweroutlet.cf['cf1'], 'foo')
 
 
-        Interface.objects.get(
-            device=d,
+        interface = Interface.objects.get(
+            device=device,
             name='Interface 1',
             name='Interface 1',
             type=InterfaceTypeChoices.TYPE_1GE_FIXED,
             type=InterfaceTypeChoices.TYPE_1GE_FIXED,
             mgmt_only=True
             mgmt_only=True
         )
         )
+        self.assertEqual(interface.cf['cf1'], 'foo')
 
 
-        rp = RearPort.objects.get(
-            device=d,
+        rearport = RearPort.objects.get(
+            device=device,
             name='Rear Port 1',
             name='Rear Port 1',
             type=PortTypeChoices.TYPE_8P8C,
             type=PortTypeChoices.TYPE_8P8C,
             positions=8
             positions=8
         )
         )
+        self.assertEqual(rearport.cf['cf1'], 'foo')
 
 
-        FrontPort.objects.get(
-            device=d,
+        frontport = FrontPort.objects.get(
+            device=device,
             name='Front Port 1',
             name='Front Port 1',
             type=PortTypeChoices.TYPE_8P8C,
             type=PortTypeChoices.TYPE_8P8C,
-            rear_port=rp,
+            rear_port=rearport,
             rear_port_position=2
             rear_port_position=2
         )
         )
+        self.assertEqual(frontport.cf['cf1'], 'foo')
 
 
-        ModuleBay.objects.get(
-            device=d,
+        modulebay = ModuleBay.objects.get(
+            device=device,
             name='Module Bay 1'
             name='Module Bay 1'
         )
         )
+        self.assertEqual(modulebay.cf['cf1'], 'foo')
 
 
-        DeviceBay.objects.get(
-            device=d,
+        devicebay = DeviceBay.objects.get(
+            device=device,
             name='Device Bay 1'
             name='Device Bay 1'
         )
         )
+        self.assertEqual(devicebay.cf['cf1'], 'foo')
+
+        inventoryitem = InventoryItem.objects.get(
+            device=device,
+            name='Inventory Item 1'
+        )
+        self.assertEqual(inventoryitem.cf['cf1'], 'foo')
 
 
     def test_multiple_unnamed_devices(self):
     def test_multiple_unnamed_devices(self):
 
 

+ 9 - 0
netbox/extras/models/customfields.py

@@ -57,6 +57,15 @@ class CustomFieldManager(models.Manager.from_queryset(RestrictedQuerySet)):
         content_type = ContentType.objects.get_for_model(model._meta.concrete_model)
         content_type = ContentType.objects.get_for_model(model._meta.concrete_model)
         return self.get_queryset().filter(content_types=content_type)
         return self.get_queryset().filter(content_types=content_type)
 
 
+    def get_defaults_for_model(self, model):
+        """
+        Return a dictionary of serialized default values for all CustomFields applicable to the given model.
+        """
+        custom_fields = self.get_for_model(model).filter(default__isnull=False)
+        return {
+            cf.name: cf.default for cf in custom_fields
+        }
+
 
 
 class CustomField(CloningMixin, ExportTemplatesMixin, ChangeLoggedModel):
 class CustomField(CloningMixin, ExportTemplatesMixin, ChangeLoggedModel):
     content_types = models.ManyToManyField(
     content_types = models.ManyToManyField(