Browse Source

feat(api): Include NAT IP fields in primary IP serializers

Add nat_inside and nat_outside fields to primary_ip, primary_ip4,
primary_ip6, and oob_ip on Device and VirtualMachine serializers.
Update prefetch logic to honor field-level constraints on nested
serializers and add test coverage for NAT field inclusion.

Fixes #19138
Martin Hauser 21 hours ago
parent
commit
ab94e3d40e

+ 24 - 4
netbox/dcim/api/serializers_/devices.py

@@ -58,10 +58,30 @@ class DeviceSerializer(PrimaryModelSerializer):
     )
     )
     status = ChoiceField(choices=DeviceStatusChoices, required=False)
     status = ChoiceField(choices=DeviceStatusChoices, required=False)
     airflow = ChoiceField(choices=DeviceAirflowChoices, allow_blank=True, required=False)
     airflow = ChoiceField(choices=DeviceAirflowChoices, allow_blank=True, required=False)
-    primary_ip = IPAddressSerializer(nested=True, read_only=True, allow_null=True)
-    primary_ip4 = IPAddressSerializer(nested=True, required=False, allow_null=True)
-    primary_ip6 = IPAddressSerializer(nested=True, required=False, allow_null=True)
-    oob_ip = IPAddressSerializer(nested=True, required=False, allow_null=True)
+    primary_ip = IPAddressSerializer(
+        nested=True,
+        read_only=True,
+        allow_null=True,
+        fields=[*IPAddressSerializer.Meta.brief_fields, 'nat_inside', 'nat_outside'],
+    )
+    primary_ip4 = IPAddressSerializer(
+        nested=True,
+        required=False,
+        allow_null=True,
+        fields=[*IPAddressSerializer.Meta.brief_fields, 'nat_inside', 'nat_outside'],
+    )
+    primary_ip6 = IPAddressSerializer(
+        nested=True,
+        required=False,
+        allow_null=True,
+        fields=[*IPAddressSerializer.Meta.brief_fields, 'nat_inside', 'nat_outside'],
+    )
+    oob_ip = IPAddressSerializer(
+        nested=True,
+        required=False,
+        allow_null=True,
+        fields=[*IPAddressSerializer.Meta.brief_fields, 'nat_inside', 'nat_outside'],
+    )
     parent_device = serializers.SerializerMethodField()
     parent_device = serializers.SerializerMethodField()
     cluster = ClusterSerializer(nested=True, required=False, allow_null=True)
     cluster = ClusterSerializer(nested=True, required=False, allow_null=True)
     virtual_chassis = VirtualChassisSerializer(nested=True, required=False, allow_null=True, default=None)
     virtual_chassis = VirtualChassisSerializer(nested=True, required=False, allow_null=True, default=None)

+ 82 - 1
netbox/dcim/tests/test_api.py

@@ -16,7 +16,13 @@ from netbox.api.serializers import GenericObjectSerializer
 from tenancy.models import Tenant
 from tenancy.models import Tenant
 from users.constants import TOKEN_PREFIX
 from users.constants import TOKEN_PREFIX
 from users.models import ObjectPermission, Token, User
 from users.models import ObjectPermission, Token, User
-from utilities.testing import APITestCase, APIViewTestCases, create_test_device, disable_logging
+from utilities.testing import (
+    APITestCase,
+    APIViewTestCases,
+    create_test_device,
+    create_test_nat_ip_pair,
+    disable_logging,
+)
 from virtualization.models import Cluster, ClusterType
 from virtualization.models import Cluster, ClusterType
 from wireless.choices import WirelessChannelChoices
 from wireless.choices import WirelessChannelChoices
 from wireless.models import WirelessLAN
 from wireless.models import WirelessLAN
@@ -1902,6 +1908,81 @@ class DeviceTest(APIViewTestCases.APIViewTestCase):
         response = self.client.post(url, {}, format='json', HTTP_AUTHORIZATION=token_header)
         response = self.client.post(url, {}, format='json', HTTP_AUTHORIZATION=token_header)
         self.assertHttpStatus(response, status.HTTP_200_OK)
         self.assertHttpStatus(response, status.HTTP_200_OK)
 
 
+    def test_list_object_includes_nat_inside_on_primary_ip(self):
+        device = create_test_device('natted-device')
+        interface = Interface.objects.create(device=device, name='eth0', type='other')
+
+        real_ip, nat_ip = create_test_nat_ip_pair(
+            real_address='10.0.0.10/32',
+            nat_address='198.51.100.10/32',
+            inside_interface=interface,
+        )
+
+        device.primary_ip4 = nat_ip
+        device.save()
+
+        self.add_permissions('dcim.view_device', 'ipam.view_ipaddress')
+        response = self.client.get(f'{self._get_list_url()}?id={device.pk}', **self.header)
+        self.assertHttpStatus(response, status.HTTP_200_OK)
+
+        result = response.data['results'][0]
+        for field in ('primary_ip', 'primary_ip4'):
+            self.assertEqual(result[field]['address'], str(nat_ip.address))
+            self.assertEqual(result[field]['nat_inside']['address'], str(real_ip.address))
+            self.assertEqual(result[field]['nat_outside'], [])
+
+    def test_get_object_includes_nat_outside_on_primary_ip(self):
+        device = create_test_device('real-ip-device')
+        interface = Interface.objects.create(device=device, name='eth0', type='other')
+
+        real_ip, nat_ip = create_test_nat_ip_pair(
+            real_address='10.0.0.11/32',
+            nat_address='198.51.100.11/32',
+            inside_interface=interface,
+        )
+
+        device.primary_ip4 = real_ip
+        device.save()
+
+        self.add_permissions('dcim.view_device', 'ipam.view_ipaddress')
+        response = self.client.get(
+            f'{self._get_detail_url(device)}?exclude=config_context',
+            **self.header,
+        )
+        self.assertHttpStatus(response, status.HTTP_200_OK)
+
+        for field in ('primary_ip', 'primary_ip4'):
+            self.assertEqual(response.data[field]['address'], str(real_ip.address))
+            self.assertIsNone(response.data[field]['nat_inside'])
+            self.assertCountEqual(
+                [ip['address'] for ip in response.data[field]['nat_outside']],
+                [str(nat_ip.address)],
+            )
+
+    def test_get_object_includes_nat_on_oob_ip(self):
+        device = create_test_device('oob-nat-device')
+        interface = Interface.objects.create(device=device, name='oob0', type='other')
+
+        real_ip, nat_ip = create_test_nat_ip_pair(
+            real_address='10.0.0.12/32',
+            nat_address='198.51.100.12/32',
+            inside_interface=interface,
+        )
+
+        device.oob_ip = nat_ip
+        device.save()
+
+        self.add_permissions('dcim.view_device', 'ipam.view_ipaddress')
+        response = self.client.get(
+            f'{self._get_detail_url(device)}?exclude=config_context',
+            **self.header,
+        )
+        self.assertHttpStatus(response, status.HTTP_200_OK)
+
+        self.assertEqual(response.data['oob_ip']['address'], str(nat_ip.address))
+        self.assertEqual(response.data['oob_ip']['nat_inside']['address'], str(real_ip.address))
+        self.assertEqual(response.data['oob_ip']['nat_outside'], [])
+
 
 
 class ModuleTest(APIViewTestCases.APIViewTestCase):
 class ModuleTest(APIViewTestCases.APIViewTestCase):
     model = Module
     model = Module

+ 33 - 10
netbox/utilities/api.py

@@ -11,7 +11,7 @@ from django.urls import reverse
 from django.utils.module_loading import import_string
 from django.utils.module_loading import import_string
 from django.utils.translation import gettext_lazy as _
 from django.utils.translation import gettext_lazy as _
 from rest_framework.permissions import BasePermission
 from rest_framework.permissions import BasePermission
-from rest_framework.serializers import Serializer
+from rest_framework.serializers import ListSerializer, Serializer
 from rest_framework.views import get_view_name as drf_get_view_name
 from rest_framework.views import get_view_name as drf_get_view_name
 
 
 from extras.constants import HTTP_CONTENT_TYPE_JSON
 from extras.constants import HTTP_CONTENT_TYPE_JSON
@@ -98,6 +98,30 @@ def get_view_name(view):
     return drf_get_view_name(view)
     return drf_get_view_name(view)
 
 
 
 
+def _get_nested_serializer(serializer_field):
+    """
+    Return the nested serializer instance for a declared serializer field.
+    """
+    if isinstance(serializer_field, ListSerializer):
+        serializer_field = serializer_field.child
+
+    if isinstance(serializer_field, Serializer) and hasattr(serializer_field, 'nested'):
+        return serializer_field
+
+    return None
+
+
+def _get_serializer_fields(serializer: Serializer):
+    """
+    Return the effective field names for a serializer instance, honoring any
+    field-level fields=/omit= overrides.
+    """
+    fields = getattr(serializer, '_include_fields', None) or serializer.Meta.fields
+    omit = getattr(serializer, '_omit_fields', []) or []
+
+    return [field_name for field_name in fields if field_name not in omit]
+
+
 def get_prefetches_for_serializer(serializer_class, fields=None, omit=None):
 def get_prefetches_for_serializer(serializer_class, fields=None, omit=None):
     """
     """
     Compile and return a list of fields which should be prefetched on the queryset for a serializer.
     Compile and return a list of fields which should be prefetched on the queryset for a serializer.
@@ -119,7 +143,7 @@ def get_prefetches_for_serializer(serializer_class, fields=None, omit=None):
 
 
         # Determine the name of the model field referenced by the serializer field
         # Determine the name of the model field referenced by the serializer field
         model_field_name = field_name
         model_field_name = field_name
-        if serializer_field and serializer_field.source:
+        if serializer_field and getattr(serializer_field, 'source', None):
             model_field_name = serializer_field.source
             model_field_name = serializer_field.source
 
 
         # If the serializer field does not map to a discrete model field, skip it.
         # If the serializer field does not map to a discrete model field, skip it.
@@ -130,14 +154,13 @@ def get_prefetches_for_serializer(serializer_class, fields=None, omit=None):
         except FieldDoesNotExist:
         except FieldDoesNotExist:
             continue
             continue
 
 
-        # If this field is represented by a nested serializer, recurse to resolve prefetches
-        # for the related object.
-        if serializer_field:
-            if issubclass(type(serializer_field), Serializer):
-                # Determine which fields to prefetch for the nested object
-                subfields = serializer_field.Meta.brief_fields if serializer_field.nested else None
-                for subfield in get_prefetches_for_serializer(type(serializer_field), subfields):
-                    prefetch_fields.append(f'{field_name}__{subfield}')
+        # If this field is represented by a nested serializer, recurse to resolve
+        # prefetches for the related object, honoring any field-level fields=/omit=
+        # constraints set on that serializer field instance.
+        if nested_serializer := _get_nested_serializer(serializer_field):
+            subfields = _get_serializer_fields(nested_serializer)
+            for subfield in get_prefetches_for_serializer(type(nested_serializer), fields=subfields):
+                prefetch_fields.append(f'{field.name}__{subfield}')
 
 
     return prefetch_fields
     return prefetch_fields
 
 

+ 23 - 0
netbox/utilities/testing/utils.py

@@ -12,6 +12,7 @@ from core.models import ObjectType
 from dcim.models import Device, DeviceRole, DeviceType, Manufacturer, Site
 from dcim.models import Device, DeviceRole, DeviceType, Manufacturer, Site
 from extras.choices import CustomFieldTypeChoices
 from extras.choices import CustomFieldTypeChoices
 from extras.models import CustomField, Tag
 from extras.models import CustomField, Tag
+from ipam.models import IPAddress
 from users.models import User
 from users.models import User
 from virtualization.models import Cluster, ClusterType, VirtualMachine
 from virtualization.models import Cluster, ClusterType, VirtualMachine
 
 
@@ -65,6 +66,28 @@ def create_test_virtualmachine(name):
     return virtual_machine
     return virtual_machine
 
 
 
 
+def create_test_nat_ip_pair(
+    real_address='10.0.0.10/32', nat_address='198.51.100.10/32', inside_interface=None, outside_interface=None
+):
+    """
+    Convenience method for creating an inside IP and its NAT outside IP.
+
+    Optionally, assign either address to an Interface or VMInterface.
+    Returns (real_ip, nat_ip).
+    """
+    real_ip = IPAddress(address=real_address)
+    if inside_interface is not None:
+        real_ip.assigned_object = inside_interface
+    real_ip.save()
+
+    nat_ip = IPAddress(address=nat_address, nat_inside=real_ip)
+    if outside_interface is not None:
+        nat_ip.assigned_object = outside_interface
+    nat_ip.save()
+
+    return real_ip, nat_ip
+
+
 def create_test_user(username='testuser', permissions=None):
 def create_test_user(username='testuser', permissions=None):
     """
     """
     Create a User with the given permissions.
     Create a User with the given permissions.

+ 81 - 1
netbox/utilities/tests/test_api.py

@@ -8,8 +8,9 @@ from dcim.models import Region, Site
 from extras.choices import CustomFieldTypeChoices
 from extras.choices import CustomFieldTypeChoices
 from extras.models import CustomField
 from extras.models import CustomField
 from ipam.models import VLAN
 from ipam.models import VLAN
+from netbox.api.serializers import BaseModelSerializer
 from netbox.config import get_config
 from netbox.config import get_config
-from utilities.api import get_view_name
+from utilities.api import get_prefetches_for_serializer, get_view_name
 from utilities.testing import APITestCase, disable_warnings
 from utilities.testing import APITestCase, disable_warnings
 
 
 
 
@@ -394,3 +395,82 @@ class GetViewNameTestCase(TestCase):
 
 
         name = get_view_name(view)
         name = get_view_name(view)
         self.assertEqual(name, 'Mock List')
         self.assertEqual(name, 'Mock List')
+
+
+class GetPrefetchesForSerializerTestCase(TestCase):
+
+    def test_nested_serializer_honors_explicit_fields(self):
+        class RegionSerializer(BaseModelSerializer):
+            class Meta:
+                model = Region
+                fields = ('id', 'name', 'parent')
+                brief_fields = ('id', 'name')
+
+        class SiteSerializer(BaseModelSerializer):
+            region = RegionSerializer(nested=True, fields=('id', 'parent'))
+
+            class Meta:
+                model = Site
+                fields = ('id', 'name', 'region')
+
+        self.assertListEqual(
+            get_prefetches_for_serializer(SiteSerializer),
+            ['region', 'region__parent'],
+        )
+
+    def test_nested_serializer_honors_explicit_omit(self):
+        class RegionSerializer(BaseModelSerializer):
+            class Meta:
+                model = Region
+                fields = ('id', 'name', 'parent')
+                brief_fields = ('id', 'name')
+
+        class SiteSerializer(BaseModelSerializer):
+            region = RegionSerializer(nested=True, omit=('name',))
+
+            class Meta:
+                model = Site
+                fields = ('id', 'name', 'region')
+
+        self.assertListEqual(
+            get_prefetches_for_serializer(SiteSerializer),
+            ['region', 'region__parent'],
+        )
+
+    def test_many_nested_serializer_honors_explicit_fields(self):
+        class SiteSerializer(BaseModelSerializer):
+            class Meta:
+                model = Site
+                fields = ('id', 'name', 'region')
+                brief_fields = ('id', 'name')
+
+        class RegionSerializer(BaseModelSerializer):
+            sites = SiteSerializer(nested=True, many=True, fields=('id', 'region'))
+
+            class Meta:
+                model = Region
+                fields = ('id', 'name', 'sites')
+
+        self.assertListEqual(
+            get_prefetches_for_serializer(RegionSerializer),
+            ['sites', 'sites__region'],
+        )
+
+    def test_nested_serializer_uses_source_for_prefetch_path(self):
+        class RegionSerializer(BaseModelSerializer):
+            class Meta:
+                model = Region
+                fields = ('id', 'name', 'parent')
+                brief_fields = ('id', 'name')
+
+        class SiteSerializer(BaseModelSerializer):
+            region_detail = RegionSerializer(source='region', nested=True, fields=('id', 'parent'))
+
+            class Meta:
+                model = Site
+                fields = ('id', 'name', 'region_detail')
+
+        self.assertListEqual(
+            get_prefetches_for_serializer(SiteSerializer),
+            ['region', 'region__parent'],
+        )

+ 19 - 5
netbox/virtualization/api/serializers_/virtualmachines.py

@@ -1,8 +1,7 @@
 from drf_spectacular.utils import extend_schema_field
 from drf_spectacular.utils import extend_schema_field
 from rest_framework import serializers
 from rest_framework import serializers
 
 
-from dcim.api.serializers_.device_components import MACAddressSerializer
-from dcim.api.serializers_.devices import DeviceSerializer
+from dcim.api.serializers_.devices import DeviceSerializer, MACAddressSerializer
 from dcim.api.serializers_.platforms import PlatformSerializer
 from dcim.api.serializers_.platforms import PlatformSerializer
 from dcim.api.serializers_.roles import DeviceRoleSerializer
 from dcim.api.serializers_.roles import DeviceRoleSerializer
 from dcim.api.serializers_.sites import SiteSerializer
 from dcim.api.serializers_.sites import SiteSerializer
@@ -58,9 +57,24 @@ class VirtualMachineSerializer(PrimaryModelSerializer):
     role = DeviceRoleSerializer(nested=True, required=False, allow_null=True)
     role = DeviceRoleSerializer(nested=True, required=False, allow_null=True)
     tenant = TenantSerializer(nested=True, required=False, allow_null=True, default=None)
     tenant = TenantSerializer(nested=True, required=False, allow_null=True, default=None)
     platform = PlatformSerializer(nested=True, required=False, allow_null=True)
     platform = PlatformSerializer(nested=True, required=False, allow_null=True)
-    primary_ip = IPAddressSerializer(nested=True, read_only=True, allow_null=True)
-    primary_ip4 = IPAddressSerializer(nested=True, required=False, allow_null=True)
-    primary_ip6 = IPAddressSerializer(nested=True, required=False, allow_null=True)
+    primary_ip = IPAddressSerializer(
+        nested=True,
+        read_only=True,
+        allow_null=True,
+        fields=[*IPAddressSerializer.Meta.brief_fields, 'nat_inside', 'nat_outside'],
+    )
+    primary_ip4 = IPAddressSerializer(
+        nested=True,
+        required=False,
+        allow_null=True,
+        fields=[*IPAddressSerializer.Meta.brief_fields, 'nat_inside', 'nat_outside'],
+    )
+    primary_ip6 = IPAddressSerializer(
+        nested=True,
+        required=False,
+        allow_null=True,
+        fields=[*IPAddressSerializer.Meta.brief_fields, 'nat_inside', 'nat_outside'],
+    )
     config_template = ConfigTemplateSerializer(nested=True, required=False, allow_null=True, default=None)
     config_template = ConfigTemplateSerializer(nested=True, required=False, allow_null=True, default=None)
 
 
     # Counter fields
     # Counter fields

+ 52 - 0
netbox/virtualization/tests/test_api.py

@@ -18,6 +18,7 @@ from utilities.testing import (
     APITestCase,
     APITestCase,
     APIViewTestCases,
     APIViewTestCases,
     create_test_device,
     create_test_device,
+    create_test_nat_ip_pair,
     create_test_virtualmachine,
     create_test_virtualmachine,
     disable_logging,
     disable_logging,
 )
 )
@@ -505,6 +506,57 @@ class VirtualMachineTest(APIViewTestCases.APIViewTestCase):
         response = self.client.post(url, {}, format='json', HTTP_AUTHORIZATION=token_header)
         response = self.client.post(url, {}, format='json', HTTP_AUTHORIZATION=token_header)
         self.assertHttpStatus(response, status.HTTP_200_OK)
         self.assertHttpStatus(response, status.HTTP_200_OK)
 
 
+    def test_list_object_includes_nat_inside_on_primary_ip(self):
+        virtualmachine = create_test_virtualmachine('natted-vm')
+        interface = VMInterface.objects.create(virtual_machine=virtualmachine, name='eth0')
+
+        real_ip, nat_ip = create_test_nat_ip_pair(
+            real_address='10.0.1.10/32',
+            nat_address='198.51.100.20/32',
+            inside_interface=interface,
+        )
+
+        virtualmachine.primary_ip4 = nat_ip
+        virtualmachine.save()
+
+        self.add_permissions('virtualization.view_virtualmachine', 'ipam.view_ipaddress')
+        response = self.client.get(f'{self._get_list_url()}?id={virtualmachine.pk}', **self.header)
+        self.assertHttpStatus(response, status.HTTP_200_OK)
+
+        result = response.data['results'][0]
+        for field in ('primary_ip', 'primary_ip4'):
+            self.assertEqual(result[field]['address'], str(nat_ip.address))
+            self.assertEqual(result[field]['nat_inside']['address'], str(real_ip.address))
+            self.assertEqual(result[field]['nat_outside'], [])
+
+    def test_get_object_includes_nat_outside_on_primary_ip(self):
+        virtualmachine = create_test_virtualmachine('real-ip-vm')
+        interface = VMInterface.objects.create(virtual_machine=virtualmachine, name='eth0')
+
+        real_ip, nat_ip = create_test_nat_ip_pair(
+            real_address='10.0.1.11/32',
+            nat_address='198.51.100.21/32',
+            inside_interface=interface,
+        )
+
+        virtualmachine.primary_ip4 = real_ip
+        virtualmachine.save()
+
+        self.add_permissions('virtualization.view_virtualmachine', 'ipam.view_ipaddress')
+        response = self.client.get(
+            f'{self._get_detail_url(virtualmachine)}?exclude=config_context',
+            **self.header,
+        )
+        self.assertHttpStatus(response, status.HTTP_200_OK)
+
+        for field in ('primary_ip', 'primary_ip4'):
+            self.assertEqual(response.data[field]['address'], str(real_ip.address))
+            self.assertIsNone(response.data[field]['nat_inside'])
+            self.assertCountEqual(
+                [ip['address'] for ip in response.data[field]['nat_outside']],
+                [str(nat_ip.address)],
+            )
+
 
 
 class VMInterfaceTest(APIViewTestCases.APIViewTestCase):
 class VMInterfaceTest(APIViewTestCases.APIViewTestCase):
     model = VMInterface
     model = VMInterface