Просмотр исходного кода

Introduced SerializedPKRelatedField to represent serialized ManyToManyFields

Jeremy Stretch 8 лет назад
Родитель
Сommit
9de1a8c363
4 измененных файлов с 41 добавлено и 17 удалено
  1. 10 2
      netbox/dcim/api/serializers.py
  2. 9 13
      netbox/dcim/tests/test_api.py
  3. 7 1
      netbox/ipam/api/serializers.py
  4. 15 1
      netbox/utilities/api.py

+ 10 - 2
netbox/dcim/api/serializers.py

@@ -20,7 +20,9 @@ from extras.api.customfields import CustomFieldModelSerializer
 from ipam.models import IPAddress, VLAN
 from ipam.models import IPAddress, VLAN
 from tenancy.api.serializers import NestedTenantSerializer
 from tenancy.api.serializers import NestedTenantSerializer
 from users.api.serializers import NestedUserSerializer
 from users.api.serializers import NestedUserSerializer
-from utilities.api import ChoiceFieldSerializer, TimeZoneField, ValidatedModelSerializer, WritableNestedSerializer
+from utilities.api import (
+    ChoiceFieldSerializer, SerializedPKRelatedField, TimeZoneField, ValidatedModelSerializer, WritableNestedSerializer,
+)
 from virtualization.models import Cluster
 from virtualization.models import Cluster
 
 
 
 
@@ -551,8 +553,14 @@ class InterfaceSerializer(ValidatedModelSerializer):
     is_connected = serializers.SerializerMethodField(read_only=True)
     is_connected = serializers.SerializerMethodField(read_only=True)
     interface_connection = serializers.SerializerMethodField(read_only=True)
     interface_connection = serializers.SerializerMethodField(read_only=True)
     circuit_termination = InterfaceCircuitTerminationSerializer(read_only=True)
     circuit_termination = InterfaceCircuitTerminationSerializer(read_only=True)
-    untagged_vlan = InterfaceVLANSerializer(required=False, allow_null=True)
     mode = ChoiceFieldSerializer(choices=IFACE_MODE_CHOICES, required=False)
     mode = ChoiceFieldSerializer(choices=IFACE_MODE_CHOICES, required=False)
+    untagged_vlan = InterfaceVLANSerializer(required=False, allow_null=True)
+    tagged_vlans = SerializedPKRelatedField(
+        queryset=VLAN.objects.all(),
+        serializer=InterfaceVLANSerializer,
+        required=False,
+        many=True
+    )
 
 
     class Meta:
     class Meta:
         model = Interface
         model = Interface

+ 9 - 13
netbox/dcim/tests/test_api.py

@@ -1,7 +1,6 @@
 from __future__ import unicode_literals
 from __future__ import unicode_literals
 
 
 from django.contrib.auth.models import User
 from django.contrib.auth.models import User
-from django.test.utils import override_settings
 from django.urls import reverse
 from django.urls import reverse
 from rest_framework import status
 from rest_framework import status
 from rest_framework.test import APITestCase
 from rest_framework.test import APITestCase
@@ -2322,15 +2321,14 @@ class InterfaceTest(HttpStatusMixin, APITestCase):
         self.assertEqual(interface4.device_id, data['device'])
         self.assertEqual(interface4.device_id, data['device'])
         self.assertEqual(interface4.name, data['name'])
         self.assertEqual(interface4.name, data['name'])
 
 
-    @override_settings(DEBUG=True)
     def test_create_interface_with_802_1q(self):
     def test_create_interface_with_802_1q(self):
 
 
         data = {
         data = {
             'device': self.device.pk,
             'device': self.device.pk,
             'name': 'Test Interface 4',
             'name': 'Test Interface 4',
             'mode': IFACE_MODE_TAGGED,
             'mode': IFACE_MODE_TAGGED,
+            'untagged_vlan': self.vlan3.id,
             'tagged_vlans': [self.vlan1.id, self.vlan2.id],
             'tagged_vlans': [self.vlan1.id, self.vlan2.id],
-            'untagged_vlan': self.vlan3.id
         }
         }
 
 
         url = reverse('dcim-api:interface-list')
         url = reverse('dcim-api:interface-list')
@@ -2338,11 +2336,10 @@ class InterfaceTest(HttpStatusMixin, APITestCase):
 
 
         self.assertHttpStatus(response, status.HTTP_201_CREATED)
         self.assertHttpStatus(response, status.HTTP_201_CREATED)
         self.assertEqual(Interface.objects.count(), 4)
         self.assertEqual(Interface.objects.count(), 4)
-        interface5 = Interface.objects.get(pk=response.data['id'])
-        self.assertEqual(interface5.device_id, data['device'])
-        self.assertEqual(interface5.name, data['name'])
-        self.assertEqual(interface5.tagged_vlans.count(), 2)
-        self.assertEqual(interface5.untagged_vlan.id, data['untagged_vlan'])
+        self.assertEqual(response.data['device']['id'], data['device'])
+        self.assertEqual(response.data['name'], data['name'])
+        self.assertEqual(response.data['untagged_vlan']['id'], data['untagged_vlan'])
+        self.assertEqual([v['id'] for v in response.data['tagged_vlans']], data['tagged_vlans'])
 
 
     def test_create_interface_bulk(self):
     def test_create_interface_bulk(self):
 
 
@@ -2370,7 +2367,6 @@ class InterfaceTest(HttpStatusMixin, APITestCase):
         self.assertEqual(response.data[1]['name'], data[1]['name'])
         self.assertEqual(response.data[1]['name'], data[1]['name'])
         self.assertEqual(response.data[2]['name'], data[2]['name'])
         self.assertEqual(response.data[2]['name'], data[2]['name'])
 
 
-    @override_settings(DEBUG=True)
     def test_create_interface_802_1q_bulk(self):
     def test_create_interface_802_1q_bulk(self):
 
 
         data = [
         data = [
@@ -2378,22 +2374,22 @@ class InterfaceTest(HttpStatusMixin, APITestCase):
                 'device': self.device.pk,
                 'device': self.device.pk,
                 'name': 'Test Interface 4',
                 'name': 'Test Interface 4',
                 'mode': IFACE_MODE_TAGGED,
                 'mode': IFACE_MODE_TAGGED,
-                'tagged_vlans': [self.vlan1.id],
                 'untagged_vlan': self.vlan2.id,
                 'untagged_vlan': self.vlan2.id,
+                'tagged_vlans': [self.vlan1.id],
             },
             },
             {
             {
                 'device': self.device.pk,
                 'device': self.device.pk,
                 'name': 'Test Interface 5',
                 'name': 'Test Interface 5',
                 'mode': IFACE_MODE_TAGGED,
                 'mode': IFACE_MODE_TAGGED,
-                'tagged_vlans': [self.vlan1.id],
                 'untagged_vlan': self.vlan2.id,
                 'untagged_vlan': self.vlan2.id,
+                'tagged_vlans': [self.vlan1.id],
             },
             },
             {
             {
                 'device': self.device.pk,
                 'device': self.device.pk,
                 'name': 'Test Interface 6',
                 'name': 'Test Interface 6',
                 'mode': IFACE_MODE_TAGGED,
                 'mode': IFACE_MODE_TAGGED,
-                'tagged_vlans': [self.vlan1.id],
                 'untagged_vlan': self.vlan2.id,
                 'untagged_vlan': self.vlan2.id,
+                'tagged_vlans': [self.vlan1.id],
             },
             },
         ]
         ]
 
 
@@ -2404,7 +2400,7 @@ class InterfaceTest(HttpStatusMixin, APITestCase):
         self.assertEqual(Interface.objects.count(), 6)
         self.assertEqual(Interface.objects.count(), 6)
         for i in range(0, 3):
         for i in range(0, 3):
             self.assertEqual(response.data[i]['name'], data[i]['name'])
             self.assertEqual(response.data[i]['name'], data[i]['name'])
-            self.assertEqual(response.data[i]['tagged_vlans'], data[i]['tagged_vlans'])
+            self.assertEqual([v['id'] for v in response.data[i]['tagged_vlans']], data[i]['tagged_vlans'])
             self.assertEqual(response.data[i]['untagged_vlan']['id'], data[i]['untagged_vlan'])
             self.assertEqual(response.data[i]['untagged_vlan']['id'], data[i]['untagged_vlan'])
 
 
     def test_update_interface(self):
     def test_update_interface(self):

+ 7 - 1
netbox/ipam/api/serializers.py

@@ -14,7 +14,7 @@ from ipam.constants import (
 )
 )
 from ipam.models import Aggregate, IPAddress, Prefix, RIR, Role, Service, VLAN, VLANGroup, VRF
 from ipam.models import Aggregate, IPAddress, Prefix, RIR, Role, Service, VLAN, VLANGroup, VRF
 from tenancy.api.serializers import NestedTenantSerializer
 from tenancy.api.serializers import NestedTenantSerializer
-from utilities.api import ChoiceFieldSerializer, ValidatedModelSerializer, WritableNestedSerializer
+from utilities.api import ChoiceFieldSerializer, SerializedPKRelatedField, ValidatedModelSerializer, WritableNestedSerializer
 from virtualization.api.serializers import NestedVirtualMachineSerializer
 from virtualization.api.serializers import NestedVirtualMachineSerializer
 
 
 
 
@@ -296,6 +296,12 @@ class ServiceSerializer(ValidatedModelSerializer):
     device = NestedDeviceSerializer(required=False, allow_null=True)
     device = NestedDeviceSerializer(required=False, allow_null=True)
     virtual_machine = NestedVirtualMachineSerializer(required=False, allow_null=True)
     virtual_machine = NestedVirtualMachineSerializer(required=False, allow_null=True)
     protocol = ChoiceFieldSerializer(choices=IP_PROTOCOL_CHOICES)
     protocol = ChoiceFieldSerializer(choices=IP_PROTOCOL_CHOICES)
+    ipaddresses = SerializedPKRelatedField(
+        queryset=IPAddress.objects.all(),
+        serializer=NestedIPAddressSerializer,
+        required=False,
+        many=True
+    )
 
 
     class Meta:
     class Meta:
         model = Service
         model = Service

+ 15 - 1
netbox/utilities/api.py

@@ -11,6 +11,7 @@ from django.http import Http404
 from rest_framework import mixins
 from rest_framework import mixins
 from rest_framework.exceptions import APIException
 from rest_framework.exceptions import APIException
 from rest_framework.permissions import BasePermission
 from rest_framework.permissions import BasePermission
+from rest_framework.relations import PrimaryKeyRelatedField
 from rest_framework.response import Response
 from rest_framework.response import Response
 from rest_framework.serializers import Field, ModelSerializer, ValidationError
 from rest_framework.serializers import Field, ModelSerializer, ValidationError
 from rest_framework.viewsets import GenericViewSet, ViewSet
 from rest_framework.viewsets import GenericViewSet, ViewSet
@@ -82,7 +83,6 @@ class TimeZoneField(Field):
     """
     """
     Represent a pytz time zone.
     Represent a pytz time zone.
     """
     """
-
     def to_representation(self, obj):
     def to_representation(self, obj):
         return obj.zone if obj else None
         return obj.zone if obj else None
 
 
@@ -95,6 +95,20 @@ class TimeZoneField(Field):
             raise ValidationError('Invalid time zone "{}"'.format(data))
             raise ValidationError('Invalid time zone "{}"'.format(data))
 
 
 
 
+class SerializedPKRelatedField(PrimaryKeyRelatedField):
+    """
+    Extends PrimaryKeyRelatedField to return a serialized object on read. This is useful for representing related
+    objects in a ManyToManyField while still allowing a set of primary keys to be written.
+    """
+    def __init__(self, serializer, **kwargs):
+        self.serializer = serializer
+        self.pk_field = kwargs.pop('pk_field', None)
+        super(SerializedPKRelatedField, self).__init__(**kwargs)
+
+    def to_representation(self, value):
+        return self.serializer(value, context={'request': self.context['request']}).data
+
+
 #
 #
 # Serializers
 # Serializers
 #
 #