Jeremy Stretch 5 лет назад
Родитель
Сommit
c85a45e520

+ 2 - 2
netbox/extras/api/customfields.py

@@ -148,10 +148,10 @@ class CustomFieldModelSerializer(ValidatedModelSerializer):
             fields = CustomField.objects.filter(obj_type=content_type)
             fields = CustomField.objects.filter(obj_type=content_type)
 
 
             # Populate CustomFieldValues for each instance from database
             # Populate CustomFieldValues for each instance from database
-            try:
+            if type(self.instance) in (list, tuple):
                 for obj in self.instance:
                 for obj in self.instance:
                     self._populate_custom_fields(obj, fields)
                     self._populate_custom_fields(obj, fields)
-            except TypeError:
+            else:
                 self._populate_custom_fields(self.instance, fields)
                 self._populate_custom_fields(self.instance, fields)
 
 
     def _populate_custom_fields(self, instance, custom_fields):
     def _populate_custom_fields(self, instance, custom_fields):

+ 3 - 16
netbox/extras/forms.py

@@ -57,26 +57,13 @@ class CustomFieldModelForm(forms.ModelForm):
             # Annotate the field in the list of CustomField form fields
             # Annotate the field in the list of CustomField form fields
             self.custom_fields.append(field_name)
             self.custom_fields.append(field_name)
 
 
-    def _save_custom_fields(self):
-
-        for field_name in self.custom_fields:
-            self.instance.custom_field_data[field_name[3:]] = self.cleaned_data[field_name]
-
     def save(self, commit=True):
     def save(self, commit=True):
 
 
-        # Cache custom field values on object prior to save to ensure change logging
+        # Save custom field data on instance
         for cf_name in self.custom_fields:
         for cf_name in self.custom_fields:
-            self.instance._cf[cf_name[3:]] = self.cleaned_data.get(cf_name)
-
-        obj = super().save(commit)
-
-        # Handle custom fields the same way we do M2M fields
-        if commit:
-            self._save_custom_fields()
-        else:
-            obj.save_custom_fields = self._save_custom_fields
+            self.instance.custom_field_data[cf_name[3:]] = self.cleaned_data.get(cf_name)
 
 
-        return obj
+        return super().save(commit)
 
 
 
 
 class CustomFieldModelCSVForm(CSVModelForm, CustomFieldModelForm):
 class CustomFieldModelCSVForm(CSVModelForm, CustomFieldModelForm):

+ 10 - 0
netbox/extras/models/customfields.py

@@ -1,3 +1,4 @@
+from collections import OrderedDict
 from datetime import date
 from datetime import date
 
 
 from django import forms
 from django import forms
@@ -34,6 +35,15 @@ class CustomFieldModel(models.Model):
         """
         """
         return self.custom_field_data
         return self.custom_field_data
 
 
+    def get_custom_fields(self):
+        """
+        Return a dictionary of custom fields for a single object in the form {<field>: value}.
+        """
+        fields = CustomField.objects.get_for_model(self)
+        return OrderedDict([
+            (field, self.custom_field_data.get(field.name)) for field in fields
+        ])
+
 
 
 class CustomFieldManager(models.Manager):
 class CustomFieldManager(models.Manager):
     use_in_migrations = True
     use_in_migrations = True

+ 52 - 74
netbox/extras/tests/test_customfields.py

@@ -174,7 +174,7 @@ class CustomFieldAPITest(APITestCase):
         }
         }
         cls.sites[1].save()
         cls.sites[1].save()
 
 
-    def test_get_single_object_without_custom_field_values(self):
+    def test_get_single_object_without_custom_field_data(self):
         """
         """
         Validate that custom fields are present on an object even if it has no values defined.
         Validate that custom fields are present on an object even if it has no values defined.
         """
         """
@@ -192,13 +192,11 @@ class CustomFieldAPITest(APITestCase):
             'choice_field': None,
             'choice_field': None,
         })
         })
 
 
-    def test_get_single_object_with_custom_field_values(self):
+    def test_get_single_object_with_custom_field_data(self):
         """
         """
         Validate that custom fields are present and correctly set for an object with values defined.
         Validate that custom fields are present and correctly set for an object with values defined.
         """
         """
-        site2_cfvs = {
-            cfv.field.name: cfv.value for cfv in self.sites[1].custom_field_values.all()
-        }
+        site2_cfvs = self.sites[1].custom_field_data
         url = reverse('dcim-api:site-detail', kwargs={'pk': self.sites[1].pk})
         url = reverse('dcim-api:site-detail', kwargs={'pk': self.sites[1].pk})
         self.add_permissions('dcim.view_site')
         self.add_permissions('dcim.view_site')
 
 
@@ -236,15 +234,12 @@ class CustomFieldAPITest(APITestCase):
 
 
         # Validate database data
         # Validate database data
         site = Site.objects.get(pk=response.data['id'])
         site = Site.objects.get(pk=response.data['id'])
-        cfvs = {
-            cfv.field.name: cfv.value for cfv in site.custom_field_values.all()
-        }
-        self.assertEqual(cfvs['text_field'], self.cf_text.default)
-        self.assertEqual(cfvs['number_field'], self.cf_integer.default)
-        self.assertEqual(cfvs['boolean_field'], self.cf_boolean.default)
-        self.assertEqual(str(cfvs['date_field']), self.cf_date.default)
-        self.assertEqual(cfvs['url_field'], self.cf_url.default)
-        self.assertEqual(cfvs['choice_field'].pk, self.cf_select_choice1.pk)
+        self.assertEqual(site.custom_field_data['text_field'], self.cf_text.default)
+        self.assertEqual(site.custom_field_data['number_field'], self.cf_integer.default)
+        self.assertEqual(site.custom_field_data['boolean_field'], self.cf_boolean.default)
+        self.assertEqual(str(site.custom_field_data['date_field']), self.cf_date.default)
+        self.assertEqual(site.custom_field_data['url_field'], self.cf_url.default)
+        self.assertEqual(site.custom_field_data['choice_field'].pk, self.cf_select_choice1.pk)
 
 
     def test_create_single_object_with_values(self):
     def test_create_single_object_with_values(self):
         """
         """
@@ -280,15 +275,12 @@ class CustomFieldAPITest(APITestCase):
 
 
         # Validate database data
         # Validate database data
         site = Site.objects.get(pk=response.data['id'])
         site = Site.objects.get(pk=response.data['id'])
-        cfvs = {
-            cfv.field.name: cfv.value for cfv in site.custom_field_values.all()
-        }
-        self.assertEqual(cfvs['text_field'], data_cf['text_field'])
-        self.assertEqual(cfvs['number_field'], data_cf['number_field'])
-        self.assertEqual(cfvs['boolean_field'], data_cf['boolean_field'])
-        self.assertEqual(str(cfvs['date_field']), data_cf['date_field'])
-        self.assertEqual(cfvs['url_field'], data_cf['url_field'])
-        self.assertEqual(cfvs['choice_field'].pk, data_cf['choice_field'])
+        self.assertEqual(site.custom_field_data['text_field'], data_cf['text_field'])
+        self.assertEqual(site.custom_field_data['number_field'], data_cf['number_field'])
+        self.assertEqual(site.custom_field_data['boolean_field'], data_cf['boolean_field'])
+        self.assertEqual(str(site.custom_field_data['date_field']), data_cf['date_field'])
+        self.assertEqual(site.custom_field_data['url_field'], data_cf['url_field'])
+        self.assertEqual(site.custom_field_data['choice_field'].pk, data_cf['choice_field'])
 
 
     def test_create_multiple_objects_with_defaults(self):
     def test_create_multiple_objects_with_defaults(self):
         """
         """
@@ -329,15 +321,12 @@ class CustomFieldAPITest(APITestCase):
 
 
             # Validate database data
             # Validate database data
             site = Site.objects.get(pk=response.data[i]['id'])
             site = Site.objects.get(pk=response.data[i]['id'])
-            cfvs = {
-                cfv.field.name: cfv.value for cfv in site.custom_field_values.all()
-            }
-            self.assertEqual(cfvs['text_field'], self.cf_text.default)
-            self.assertEqual(cfvs['number_field'], self.cf_integer.default)
-            self.assertEqual(cfvs['boolean_field'], self.cf_boolean.default)
-            self.assertEqual(str(cfvs['date_field']), self.cf_date.default)
-            self.assertEqual(cfvs['url_field'], self.cf_url.default)
-            self.assertEqual(cfvs['choice_field'].pk, self.cf_select_choice1.pk)
+            self.assertEqual(site.custom_field_data['text_field'], self.cf_text.default)
+            self.assertEqual(site.custom_field_data['number_field'], self.cf_integer.default)
+            self.assertEqual(site.custom_field_data['boolean_field'], self.cf_boolean.default)
+            self.assertEqual(str(site.custom_field_data['date_field']), self.cf_date.default)
+            self.assertEqual(site.custom_field_data['url_field'], self.cf_url.default)
+            self.assertEqual(site.custom_field_data['choice_field'].pk, self.cf_select_choice1.pk)
 
 
     def test_create_multiple_objects_with_values(self):
     def test_create_multiple_objects_with_values(self):
         """
         """
@@ -388,24 +377,20 @@ class CustomFieldAPITest(APITestCase):
 
 
             # Validate database data
             # Validate database data
             site = Site.objects.get(pk=response.data[i]['id'])
             site = Site.objects.get(pk=response.data[i]['id'])
-            cfvs = {
-                cfv.field.name: cfv.value for cfv in site.custom_field_values.all()
-            }
-            self.assertEqual(cfvs['text_field'], custom_field_data['text_field'])
-            self.assertEqual(cfvs['number_field'], custom_field_data['number_field'])
-            self.assertEqual(cfvs['boolean_field'], custom_field_data['boolean_field'])
-            self.assertEqual(str(cfvs['date_field']), custom_field_data['date_field'])
-            self.assertEqual(cfvs['url_field'], custom_field_data['url_field'])
-            self.assertEqual(cfvs['choice_field'].pk, custom_field_data['choice_field'])
+            self.assertEqual(site.custom_field_data['text_field'], custom_field_data['text_field'])
+            self.assertEqual(site.custom_field_data['number_field'], custom_field_data['number_field'])
+            self.assertEqual(site.custom_field_data['boolean_field'], custom_field_data['boolean_field'])
+            self.assertEqual(str(site.custom_field_data['date_field']), custom_field_data['date_field'])
+            self.assertEqual(site.custom_field_data['url_field'], custom_field_data['url_field'])
+            self.assertEqual(site.custom_field_data['choice_field'].pk, custom_field_data['choice_field'])
 
 
     def test_update_single_object_with_values(self):
     def test_update_single_object_with_values(self):
         """
         """
         Update an object with existing custom field values. Ensure that only the updated custom field values are
         Update an object with existing custom field values. Ensure that only the updated custom field values are
         modified.
         modified.
         """
         """
-        site2_original_cfvs = {
-            cfv.field.name: cfv.value for cfv in self.sites[1].custom_field_values.all()
-        }
+        site = self.sites[1]
+        original_cfvs = {**site.custom_field_data}
         data = {
         data = {
             'custom_fields': {
             'custom_fields': {
                 'text_field': 'ABCD',
                 'text_field': 'ABCD',
@@ -430,15 +415,13 @@ class CustomFieldAPITest(APITestCase):
         # self.assertEqual(response_cf['choice_field']['label'], site2_original_cfvs['choice_field'].value)
         # self.assertEqual(response_cf['choice_field']['label'], site2_original_cfvs['choice_field'].value)
 
 
         # Validate database data
         # Validate database data
-        site2_updated_cfvs = {
-            cfv.field.name: cfv.value for cfv in self.sites[1].custom_field_values.all()
-        }
-        self.assertEqual(site2_updated_cfvs['text_field'], data_cf['text_field'])
-        self.assertEqual(site2_updated_cfvs['number_field'], data_cf['number_field'])
-        self.assertEqual(site2_updated_cfvs['boolean_field'], site2_original_cfvs['boolean_field'])
-        self.assertEqual(site2_updated_cfvs['date_field'], site2_original_cfvs['date_field'])
-        self.assertEqual(site2_updated_cfvs['url_field'], site2_original_cfvs['url_field'])
-        self.assertEqual(site2_updated_cfvs['choice_field'], site2_original_cfvs['choice_field'])
+        site.refresh_from_db()
+        self.assertEqual(site.custom_field_data['text_field'], data_cf['text_field'])
+        self.assertEqual(site.custom_field_data['number_field'], data_cf['number_field'])
+        self.assertEqual(site.custom_field_data['boolean_field'], original_cfvs['boolean_field'])
+        self.assertEqual(site.custom_field_data['date_field'], original_cfvs['date_field'])
+        self.assertEqual(site.custom_field_data['url_field'], original_cfvs['url_field'])
+        self.assertEqual(site.custom_field_data['choice_field'], original_cfvs['choice_field'])
 
 
 
 
 class CustomFieldChoiceAPITest(APITestCase):
 class CustomFieldChoiceAPITest(APITestCase):
@@ -514,31 +497,26 @@ class CustomFieldImportTest(TestCase):
         self.assertEqual(response.status_code, 200)
         self.assertEqual(response.status_code, 200)
 
 
         # Validate data for site 1
         # Validate data for site 1
-        custom_field_values = {
-            cf.name: value for cf, value in Site.objects.get(name='Site 1').custom_field_data
-        }
-        self.assertEqual(len(custom_field_values), 6)
-        self.assertEqual(custom_field_values['text'], 'ABC')
-        self.assertEqual(custom_field_values['integer'], 123)
-        self.assertEqual(custom_field_values['boolean'], True)
-        self.assertEqual(custom_field_values['date'], date(2020, 1, 1))
-        self.assertEqual(custom_field_values['url'], 'http://example.com/1')
-        self.assertEqual(custom_field_values['select'].value, 'Choice A')
+        site1 = Site.objects.get(name='Site 1')
+        self.assertEqual(len(site1.custom_field_data), 6)
+        self.assertEqual(site1.custom_field_data['text'], 'ABC')
+        self.assertEqual(site1.custom_field_data['integer'], 123)
+        self.assertEqual(site1.custom_field_data['boolean'], True)
+        self.assertEqual(site1.custom_field_data['date'], date(2020, 1, 1))
+        self.assertEqual(site1.custom_field_data['url'], 'http://example.com/1')
+        self.assertEqual(site1.custom_field_data['select'].value, 'Choice A')
 
 
         # Validate data for site 2
         # Validate data for site 2
-        custom_field_values = {
-            cf.name: value for cf, value in Site.objects.get(name='Site 2').custom_field_data
-        }
-        self.assertEqual(len(custom_field_values), 6)
-        self.assertEqual(custom_field_values['text'], 'DEF')
-        self.assertEqual(custom_field_values['integer'], 456)
-        self.assertEqual(custom_field_values['boolean'], False)
-        self.assertEqual(custom_field_values['date'], date(2020, 1, 2))
-        self.assertEqual(custom_field_values['url'], 'http://example.com/2')
-        self.assertEqual(custom_field_values['select'].value, 'Choice B')
+        site2 = Site.objects.get(name='Site 2')
+        self.assertEqual(len(site2.custom_field_data), 6)
+        self.assertEqual(site2.custom_field_data['text'], 'DEF')
+        self.assertEqual(site2.custom_field_data['integer'], 456)
+        self.assertEqual(site2.custom_field_data['boolean'], False)
+        self.assertEqual(site2.custom_field_data['date'], date(2020, 1, 2))
+        self.assertEqual(site2.custom_field_data['url'], 'http://example.com/2')
+        self.assertEqual(site2.custom_field_data['select'].value, 'Choice B')
 
 
         # No CustomFieldValues should be created for site 3
         # No CustomFieldValues should be created for site 3
-        obj_type = ContentType.objects.get_for_model(Site)
         site3 = Site.objects.get(name='Site 3')
         site3 = Site.objects.get(name='Site 3')
         self.assertEqual(site3.custom_field_data, {})
         self.assertEqual(site3.custom_field_data, {})