Просмотр исходного кода

Introduce CustomFieldDefaultValues class to handle default custom field values

Jeremy Stretch 6 лет назад
Родитель
Сommit
e4abfd192e
1 измененных файлов с 50 добавлено и 37 удалено
  1. 50 37
      netbox/extras/api/customfields.py

+ 50 - 37
netbox/extras/api/customfields.py

@@ -4,6 +4,7 @@ from django.contrib.contenttypes.models import ContentType
 from django.db import transaction
 from django.db import transaction
 from rest_framework import serializers
 from rest_framework import serializers
 from rest_framework.exceptions import ValidationError
 from rest_framework.exceptions import ValidationError
+from rest_framework.fields import CreateOnlyDefault
 
 
 from extras.choices import *
 from extras.choices import *
 from extras.models import CustomField, CustomFieldChoice, CustomFieldValue
 from extras.models import CustomField, CustomFieldChoice, CustomFieldValue
@@ -14,6 +15,36 @@ from utilities.api import ValidatedModelSerializer
 # Custom fields
 # Custom fields
 #
 #
 
 
+class CustomFieldDefaultValues:
+    """
+    Return a dictionary of all CustomFields assigned to the parent model and their default values.
+    """
+    def __call__(self):
+
+        # Retrieve the CustomFields for the parent model
+        content_type = ContentType.objects.get_for_model(self.model)
+        fields = CustomField.objects.filter(obj_type=content_type)
+
+        # Populate the default value for each CustomField
+        value = {}
+        for field in fields:
+            if field.default:
+                if field.type == CustomFieldTypeChoices.TYPE_SELECT:
+                    field_value = field.choices.get(value=field.default).pk
+                elif field.type == CustomFieldTypeChoices.TYPE_BOOLEAN:
+                    field_value = bool(field.default)
+                else:
+                    field_value = field.default
+                value[field.name] = field_value
+            else:
+                value[field.name] = None
+
+        return value
+
+    def set_context(self, serializer_field):
+        self.model = serializer_field.parent.Meta.model
+
+
 class CustomFieldsSerializer(serializers.BaseSerializer):
 class CustomFieldsSerializer(serializers.BaseSerializer):
 
 
     def to_representation(self, obj):
     def to_representation(self, obj):
@@ -94,53 +125,35 @@ class CustomFieldModelSerializer(ValidatedModelSerializer):
     """
     """
     Extends ModelSerializer to render any CustomFields and their values associated with an object.
     Extends ModelSerializer to render any CustomFields and their values associated with an object.
     """
     """
-    custom_fields = CustomFieldsSerializer(required=False)
+    custom_fields = CustomFieldsSerializer(
+        required=False,
+        default=CreateOnlyDefault(CustomFieldDefaultValues())
+    )
 
 
     def __init__(self, *args, **kwargs):
     def __init__(self, *args, **kwargs):
-
-        def _populate_custom_fields(instance, fields):
-            instance.custom_fields = {}
-            for field in fields:
-                value = instance.cf.get(field.name)
-                if field.type == CustomFieldTypeChoices.TYPE_SELECT and value is not None:
-                    instance.custom_fields[field.name] = CustomFieldChoiceSerializer(value).data
-                else:
-                    instance.custom_fields[field.name] = value
-
         super().__init__(*args, **kwargs)
         super().__init__(*args, **kwargs)
 
 
-        # Retrieve the set of CustomFields which apply to this type of object
-        content_type = ContentType.objects.get_for_model(self.Meta.model)
-        fields = CustomField.objects.filter(obj_type=content_type)
-
         if self.instance is not None:
         if self.instance is not None:
 
 
+            # Retrieve the set of CustomFields which apply to this type of object
+            content_type = ContentType.objects.get_for_model(self.Meta.model)
+            fields = CustomField.objects.filter(obj_type=content_type)
+
             # Populate CustomFieldValues for each instance from database
             # Populate CustomFieldValues for each instance from database
             try:
             try:
                 for obj in self.instance:
                 for obj in self.instance:
-                    _populate_custom_fields(obj, fields)
+                    self._populate_custom_fields(obj, fields)
             except TypeError:
             except TypeError:
-                _populate_custom_fields(self.instance, fields)
-
-        else:
-
-            if not hasattr(self, 'initial_data'):
-                self.initial_data = {}
-
-            # Populate default values
-            if fields and 'custom_fields' not in self.initial_data:
-                self.initial_data['custom_fields'] = {}
-
-            # Populate initial data using custom field default values
-            for field in fields:
-                if field.name not in self.initial_data['custom_fields'] and field.default:
-                    if field.type == CustomFieldTypeChoices.TYPE_SELECT:
-                        field_value = field.choices.get(value=field.default).pk
-                    elif field.type == CustomFieldTypeChoices.TYPE_BOOLEAN:
-                        field_value = bool(field.default)
-                    else:
-                        field_value = field.default
-                    self.initial_data['custom_fields'][field.name] = field_value
+                self._populate_custom_fields(self.instance, fields)
+
+    def _populate_custom_fields(self, instance, custom_fields):
+        instance.custom_fields = {}
+        for field in custom_fields:
+            value = instance.cf.get(field.name)
+            if field.type == CustomFieldTypeChoices.TYPE_SELECT and value is not None:
+                instance.custom_fields[field.name] = CustomFieldChoiceSerializer(value).data
+            else:
+                instance.custom_fields[field.name] = value
 
 
     def _save_custom_fields(self, instance, custom_fields):
     def _save_custom_fields(self, instance, custom_fields):
         content_type = ContentType.objects.get_for_model(self.Meta.model)
         content_type = ContentType.objects.get_for_model(self.Meta.model)