Jeremy Stretch 5 лет назад
Родитель
Сommit
c85a45e520

+ 2 - 2
netbox/extras/api/customfields.py

@@ -148,10 +148,10 @@ class CustomFieldModelSerializer(ValidatedModelSerializer):
             fields = CustomField.objects.filter(obj_type=content_type)
 
             # Populate CustomFieldValues for each instance from database
-            try:
+            if type(self.instance) in (list, tuple):
                 for obj in self.instance:
                     self._populate_custom_fields(obj, fields)
-            except TypeError:
+            else:
                 self._populate_custom_fields(self.instance, fields)
 
     def _populate_custom_fields(self, instance, custom_fields):

+ 3 - 16
netbox/extras/forms.py

@@ -57,26 +57,13 @@ class CustomFieldModelForm(forms.ModelForm):
             # Annotate the field in the list of CustomField form fields
             self.custom_fields.append(field_name)
 
-    def _save_custom_fields(self):
-
-        for field_name in self.custom_fields:
-            self.instance.custom_field_data[field_name[3:]] = self.cleaned_data[field_name]
-
     def save(self, commit=True):
 
-        # Cache custom field values on object prior to save to ensure change logging
+        # Save custom field data on instance
         for cf_name in self.custom_fields:
-            self.instance._cf[cf_name[3:]] = self.cleaned_data.get(cf_name)
-
-        obj = super().save(commit)
-
-        # Handle custom fields the same way we do M2M fields
-        if commit:
-            self._save_custom_fields()
-        else:
-            obj.save_custom_fields = self._save_custom_fields
+            self.instance.custom_field_data[cf_name[3:]] = self.cleaned_data.get(cf_name)
 
-        return obj
+        return super().save(commit)
 
 
 class CustomFieldModelCSVForm(CSVModelForm, CustomFieldModelForm):

+ 10 - 0
netbox/extras/models/customfields.py

@@ -1,3 +1,4 @@
+from collections import OrderedDict
 from datetime import date
 
 from django import forms
@@ -34,6 +35,15 @@ class CustomFieldModel(models.Model):
         """
         return self.custom_field_data
 
+    def get_custom_fields(self):
+        """
+        Return a dictionary of custom fields for a single object in the form {<field>: value}.
+        """
+        fields = CustomField.objects.get_for_model(self)
+        return OrderedDict([
+            (field, self.custom_field_data.get(field.name)) for field in fields
+        ])
+
 
 class CustomFieldManager(models.Manager):
     use_in_migrations = True

+ 52 - 74
netbox/extras/tests/test_customfields.py

@@ -174,7 +174,7 @@ class CustomFieldAPITest(APITestCase):
         }
         cls.sites[1].save()
 
-    def test_get_single_object_without_custom_field_values(self):
+    def test_get_single_object_without_custom_field_data(self):
         """
         Validate that custom fields are present on an object even if it has no values defined.
         """
@@ -192,13 +192,11 @@ class CustomFieldAPITest(APITestCase):
             'choice_field': None,
         })
 
-    def test_get_single_object_with_custom_field_values(self):
+    def test_get_single_object_with_custom_field_data(self):
         """
         Validate that custom fields are present and correctly set for an object with values defined.
         """
-        site2_cfvs = {
-            cfv.field.name: cfv.value for cfv in self.sites[1].custom_field_values.all()
-        }
+        site2_cfvs = self.sites[1].custom_field_data
         url = reverse('dcim-api:site-detail', kwargs={'pk': self.sites[1].pk})
         self.add_permissions('dcim.view_site')
 
@@ -236,15 +234,12 @@ class CustomFieldAPITest(APITestCase):
 
         # Validate database data
         site = Site.objects.get(pk=response.data['id'])
-        cfvs = {
-            cfv.field.name: cfv.value for cfv in site.custom_field_values.all()
-        }
-        self.assertEqual(cfvs['text_field'], self.cf_text.default)
-        self.assertEqual(cfvs['number_field'], self.cf_integer.default)
-        self.assertEqual(cfvs['boolean_field'], self.cf_boolean.default)
-        self.assertEqual(str(cfvs['date_field']), self.cf_date.default)
-        self.assertEqual(cfvs['url_field'], self.cf_url.default)
-        self.assertEqual(cfvs['choice_field'].pk, self.cf_select_choice1.pk)
+        self.assertEqual(site.custom_field_data['text_field'], self.cf_text.default)
+        self.assertEqual(site.custom_field_data['number_field'], self.cf_integer.default)
+        self.assertEqual(site.custom_field_data['boolean_field'], self.cf_boolean.default)
+        self.assertEqual(str(site.custom_field_data['date_field']), self.cf_date.default)
+        self.assertEqual(site.custom_field_data['url_field'], self.cf_url.default)
+        self.assertEqual(site.custom_field_data['choice_field'].pk, self.cf_select_choice1.pk)
 
     def test_create_single_object_with_values(self):
         """
@@ -280,15 +275,12 @@ class CustomFieldAPITest(APITestCase):
 
         # Validate database data
         site = Site.objects.get(pk=response.data['id'])
-        cfvs = {
-            cfv.field.name: cfv.value for cfv in site.custom_field_values.all()
-        }
-        self.assertEqual(cfvs['text_field'], data_cf['text_field'])
-        self.assertEqual(cfvs['number_field'], data_cf['number_field'])
-        self.assertEqual(cfvs['boolean_field'], data_cf['boolean_field'])
-        self.assertEqual(str(cfvs['date_field']), data_cf['date_field'])
-        self.assertEqual(cfvs['url_field'], data_cf['url_field'])
-        self.assertEqual(cfvs['choice_field'].pk, data_cf['choice_field'])
+        self.assertEqual(site.custom_field_data['text_field'], data_cf['text_field'])
+        self.assertEqual(site.custom_field_data['number_field'], data_cf['number_field'])
+        self.assertEqual(site.custom_field_data['boolean_field'], data_cf['boolean_field'])
+        self.assertEqual(str(site.custom_field_data['date_field']), data_cf['date_field'])
+        self.assertEqual(site.custom_field_data['url_field'], data_cf['url_field'])
+        self.assertEqual(site.custom_field_data['choice_field'].pk, data_cf['choice_field'])
 
     def test_create_multiple_objects_with_defaults(self):
         """
@@ -329,15 +321,12 @@ class CustomFieldAPITest(APITestCase):
 
             # Validate database data
             site = Site.objects.get(pk=response.data[i]['id'])
-            cfvs = {
-                cfv.field.name: cfv.value for cfv in site.custom_field_values.all()
-            }
-            self.assertEqual(cfvs['text_field'], self.cf_text.default)
-            self.assertEqual(cfvs['number_field'], self.cf_integer.default)
-            self.assertEqual(cfvs['boolean_field'], self.cf_boolean.default)
-            self.assertEqual(str(cfvs['date_field']), self.cf_date.default)
-            self.assertEqual(cfvs['url_field'], self.cf_url.default)
-            self.assertEqual(cfvs['choice_field'].pk, self.cf_select_choice1.pk)
+            self.assertEqual(site.custom_field_data['text_field'], self.cf_text.default)
+            self.assertEqual(site.custom_field_data['number_field'], self.cf_integer.default)
+            self.assertEqual(site.custom_field_data['boolean_field'], self.cf_boolean.default)
+            self.assertEqual(str(site.custom_field_data['date_field']), self.cf_date.default)
+            self.assertEqual(site.custom_field_data['url_field'], self.cf_url.default)
+            self.assertEqual(site.custom_field_data['choice_field'].pk, self.cf_select_choice1.pk)
 
     def test_create_multiple_objects_with_values(self):
         """
@@ -388,24 +377,20 @@ class CustomFieldAPITest(APITestCase):
 
             # Validate database data
             site = Site.objects.get(pk=response.data[i]['id'])
-            cfvs = {
-                cfv.field.name: cfv.value for cfv in site.custom_field_values.all()
-            }
-            self.assertEqual(cfvs['text_field'], custom_field_data['text_field'])
-            self.assertEqual(cfvs['number_field'], custom_field_data['number_field'])
-            self.assertEqual(cfvs['boolean_field'], custom_field_data['boolean_field'])
-            self.assertEqual(str(cfvs['date_field']), custom_field_data['date_field'])
-            self.assertEqual(cfvs['url_field'], custom_field_data['url_field'])
-            self.assertEqual(cfvs['choice_field'].pk, custom_field_data['choice_field'])
+            self.assertEqual(site.custom_field_data['text_field'], custom_field_data['text_field'])
+            self.assertEqual(site.custom_field_data['number_field'], custom_field_data['number_field'])
+            self.assertEqual(site.custom_field_data['boolean_field'], custom_field_data['boolean_field'])
+            self.assertEqual(str(site.custom_field_data['date_field']), custom_field_data['date_field'])
+            self.assertEqual(site.custom_field_data['url_field'], custom_field_data['url_field'])
+            self.assertEqual(site.custom_field_data['choice_field'].pk, custom_field_data['choice_field'])
 
     def test_update_single_object_with_values(self):
         """
         Update an object with existing custom field values. Ensure that only the updated custom field values are
         modified.
         """
-        site2_original_cfvs = {
-            cfv.field.name: cfv.value for cfv in self.sites[1].custom_field_values.all()
-        }
+        site = self.sites[1]
+        original_cfvs = {**site.custom_field_data}
         data = {
             'custom_fields': {
                 'text_field': 'ABCD',
@@ -430,15 +415,13 @@ class CustomFieldAPITest(APITestCase):
         # self.assertEqual(response_cf['choice_field']['label'], site2_original_cfvs['choice_field'].value)
 
         # Validate database data
-        site2_updated_cfvs = {
-            cfv.field.name: cfv.value for cfv in self.sites[1].custom_field_values.all()
-        }
-        self.assertEqual(site2_updated_cfvs['text_field'], data_cf['text_field'])
-        self.assertEqual(site2_updated_cfvs['number_field'], data_cf['number_field'])
-        self.assertEqual(site2_updated_cfvs['boolean_field'], site2_original_cfvs['boolean_field'])
-        self.assertEqual(site2_updated_cfvs['date_field'], site2_original_cfvs['date_field'])
-        self.assertEqual(site2_updated_cfvs['url_field'], site2_original_cfvs['url_field'])
-        self.assertEqual(site2_updated_cfvs['choice_field'], site2_original_cfvs['choice_field'])
+        site.refresh_from_db()
+        self.assertEqual(site.custom_field_data['text_field'], data_cf['text_field'])
+        self.assertEqual(site.custom_field_data['number_field'], data_cf['number_field'])
+        self.assertEqual(site.custom_field_data['boolean_field'], original_cfvs['boolean_field'])
+        self.assertEqual(site.custom_field_data['date_field'], original_cfvs['date_field'])
+        self.assertEqual(site.custom_field_data['url_field'], original_cfvs['url_field'])
+        self.assertEqual(site.custom_field_data['choice_field'], original_cfvs['choice_field'])
 
 
 class CustomFieldChoiceAPITest(APITestCase):
@@ -514,31 +497,26 @@ class CustomFieldImportTest(TestCase):
         self.assertEqual(response.status_code, 200)
 
         # Validate data for site 1
-        custom_field_values = {
-            cf.name: value for cf, value in Site.objects.get(name='Site 1').custom_field_data
-        }
-        self.assertEqual(len(custom_field_values), 6)
-        self.assertEqual(custom_field_values['text'], 'ABC')
-        self.assertEqual(custom_field_values['integer'], 123)
-        self.assertEqual(custom_field_values['boolean'], True)
-        self.assertEqual(custom_field_values['date'], date(2020, 1, 1))
-        self.assertEqual(custom_field_values['url'], 'http://example.com/1')
-        self.assertEqual(custom_field_values['select'].value, 'Choice A')
+        site1 = Site.objects.get(name='Site 1')
+        self.assertEqual(len(site1.custom_field_data), 6)
+        self.assertEqual(site1.custom_field_data['text'], 'ABC')
+        self.assertEqual(site1.custom_field_data['integer'], 123)
+        self.assertEqual(site1.custom_field_data['boolean'], True)
+        self.assertEqual(site1.custom_field_data['date'], date(2020, 1, 1))
+        self.assertEqual(site1.custom_field_data['url'], 'http://example.com/1')
+        self.assertEqual(site1.custom_field_data['select'].value, 'Choice A')
 
         # Validate data for site 2
-        custom_field_values = {
-            cf.name: value for cf, value in Site.objects.get(name='Site 2').custom_field_data
-        }
-        self.assertEqual(len(custom_field_values), 6)
-        self.assertEqual(custom_field_values['text'], 'DEF')
-        self.assertEqual(custom_field_values['integer'], 456)
-        self.assertEqual(custom_field_values['boolean'], False)
-        self.assertEqual(custom_field_values['date'], date(2020, 1, 2))
-        self.assertEqual(custom_field_values['url'], 'http://example.com/2')
-        self.assertEqual(custom_field_values['select'].value, 'Choice B')
+        site2 = Site.objects.get(name='Site 2')
+        self.assertEqual(len(site2.custom_field_data), 6)
+        self.assertEqual(site2.custom_field_data['text'], 'DEF')
+        self.assertEqual(site2.custom_field_data['integer'], 456)
+        self.assertEqual(site2.custom_field_data['boolean'], False)
+        self.assertEqual(site2.custom_field_data['date'], date(2020, 1, 2))
+        self.assertEqual(site2.custom_field_data['url'], 'http://example.com/2')
+        self.assertEqual(site2.custom_field_data['select'].value, 'Choice B')
 
         # No CustomFieldValues should be created for site 3
-        obj_type = ContentType.objects.get_for_model(Site)
         site3 = Site.objects.get(name='Site 3')
         self.assertEqual(site3.custom_field_data, {})