Преглед на файлове

Merge pull request #21806 from netbox-community/21771-rest-api-add-remove-tags

Closes #21771: Add `add_tags` & `remove_tags` fields for taggable objects
bctiemann преди 3 дни
родител
ревизия
1277bb6138
променени са 3 файла, в които са добавени 289 реда и са изтрити 4 реда
  1. 219 2
      netbox/dcim/tests/test_api.py
  2. 68 0
      netbox/netbox/api/serializers/features.py
  3. 2 2
      netbox/utilities/testing/api.py

+ 219 - 2
netbox/dcim/tests/test_api.py

@@ -5,16 +5,17 @@ from django.urls import reverse
 from django.utils.translation import gettext as _
 from django.utils.translation import gettext as _
 from rest_framework import status
 from rest_framework import status
 
 
+from core.models import ObjectType
 from dcim.choices import *
 from dcim.choices import *
 from dcim.constants import *
 from dcim.constants import *
 from dcim.models import *
 from dcim.models import *
-from extras.models import ConfigTemplate
+from extras.models import ConfigTemplate, Tag
 from ipam.choices import VLANQinQRoleChoices
 from ipam.choices import VLANQinQRoleChoices
 from ipam.models import ASN, RIR, VLAN, VRF
 from ipam.models import ASN, RIR, VLAN, VRF
 from netbox.api.serializers import GenericObjectSerializer
 from netbox.api.serializers import GenericObjectSerializer
 from tenancy.models import Tenant
 from tenancy.models import Tenant
 from users.constants import TOKEN_PREFIX
 from users.constants import TOKEN_PREFIX
-from users.models import Token, User
+from users.models import ObjectPermission, Token, User
 from utilities.testing import APITestCase, APIViewTestCases, create_test_device, disable_logging
 from utilities.testing import APITestCase, APIViewTestCases, create_test_device, disable_logging
 from virtualization.models import Cluster, ClusterType
 from virtualization.models import Cluster, ClusterType
 from wireless.choices import WirelessChannelChoices
 from wireless.choices import WirelessChannelChoices
@@ -195,6 +196,222 @@ class SiteTest(APIViewTestCases.APIViewTestCase):
             },
             },
         ]
         ]
 
 
+    def test_add_tags(self):
+        """
+        Add tags to an existing object via the add_tags field.
+        """
+        site = Site.objects.first()
+        tags = Tag.objects.bulk_create((
+            Tag(name='Alpha', slug='alpha'),
+            Tag(name='Bravo', slug='bravo'),
+            Tag(name='Charlie', slug='charlie'),
+        ))
+        site.tags.set([tags[0], tags[1]])
+
+        # Grant change permission
+        obj_perm = ObjectPermission(name='Test permission', actions=['change'])
+        obj_perm.save()
+        obj_perm.users.add(self.user)
+        obj_perm.object_types.add(ObjectType.objects.get_for_model(self.model))
+
+        url = self._get_detail_url(site)
+        data = {
+            'add_tags': [{'name': 'Charlie'}],
+        }
+        response = self.client.patch(url, data, format='json', **self.header)
+        self.assertHttpStatus(response, status.HTTP_200_OK)
+
+        # Verify all three tags are now assigned
+        tag_names = sorted(site.tags.values_list('name', flat=True))
+        self.assertEqual(tag_names, ['Alpha', 'Bravo', 'Charlie'])
+
+        # Verify add_tags and remove_tags are not in the response
+        self.assertNotIn('add_tags', response.data)
+        self.assertNotIn('remove_tags', response.data)
+        self.assertIn('tags', response.data)
+
+    def test_remove_tags(self):
+        """
+        Remove tags from an existing object via the remove_tags field.
+        """
+        site = Site.objects.first()
+        tags = Tag.objects.bulk_create((
+            Tag(name='Alpha', slug='alpha'),
+            Tag(name='Bravo', slug='bravo'),
+            Tag(name='Charlie', slug='charlie'),
+        ))
+        site.tags.set(tags)
+
+        # Grant change permission
+        obj_perm = ObjectPermission(name='Test permission', actions=['change'])
+        obj_perm.save()
+        obj_perm.users.add(self.user)
+        obj_perm.object_types.add(ObjectType.objects.get_for_model(self.model))
+
+        url = self._get_detail_url(site)
+        data = {
+            'remove_tags': [{'name': 'Charlie'}],
+        }
+        response = self.client.patch(url, data, format='json', **self.header)
+        self.assertHttpStatus(response, status.HTTP_200_OK)
+
+        # Verify only Alpha and Bravo remain
+        tag_names = sorted(site.tags.values_list('name', flat=True))
+        self.assertEqual(tag_names, ['Alpha', 'Bravo'])
+
+    def test_remove_tags_not_assigned(self):
+        """
+        Removing a tag that is not assigned should not raise an error.
+        """
+        site = Site.objects.first()
+        tags = Tag.objects.bulk_create((
+            Tag(name='Alpha', slug='alpha'),
+            Tag(name='Bravo', slug='bravo'),
+            Tag(name='Charlie', slug='charlie'),
+        ))
+        site.tags.set([tags[0], tags[1]])
+
+        # Grant change permission
+        obj_perm = ObjectPermission(name='Test permission', actions=['change'])
+        obj_perm.save()
+        obj_perm.users.add(self.user)
+        obj_perm.object_types.add(ObjectType.objects.get_for_model(self.model))
+
+        url = self._get_detail_url(site)
+        data = {
+            'remove_tags': [{'name': 'Charlie'}],
+        }
+        response = self.client.patch(url, data, format='json', **self.header)
+        self.assertHttpStatus(response, status.HTTP_200_OK)
+
+        # Tags should be unchanged
+        tag_names = sorted(site.tags.values_list('name', flat=True))
+        self.assertEqual(tag_names, ['Alpha', 'Bravo'])
+
+    def test_add_and_remove_tags(self):
+        """
+        Add and remove tags in the same request.
+        """
+        site = Site.objects.first()
+        tags = Tag.objects.bulk_create((
+            Tag(name='Alpha', slug='alpha'),
+            Tag(name='Bravo', slug='bravo'),
+            Tag(name='Charlie', slug='charlie'),
+        ))
+        site.tags.set([tags[0], tags[1]])
+
+        # Grant change permission
+        obj_perm = ObjectPermission(name='Test permission', actions=['change'])
+        obj_perm.save()
+        obj_perm.users.add(self.user)
+        obj_perm.object_types.add(ObjectType.objects.get_for_model(self.model))
+
+        url = self._get_detail_url(site)
+        data = {
+            'add_tags': [{'name': 'Charlie'}],
+            'remove_tags': [{'name': 'Alpha'}],
+        }
+        response = self.client.patch(url, data, format='json', **self.header)
+        self.assertHttpStatus(response, status.HTTP_200_OK)
+
+        # Verify Bravo and Charlie remain
+        tag_names = sorted(site.tags.values_list('name', flat=True))
+        self.assertEqual(tag_names, ['Bravo', 'Charlie'])
+
+    def test_tags_with_add_tags_error(self):
+        """
+        Specifying tags together with add_tags or remove_tags should raise a validation error.
+        """
+        site = Site.objects.first()
+        Tag.objects.bulk_create((
+            Tag(name='Alpha', slug='alpha'),
+            Tag(name='Bravo', slug='bravo'),
+        ))
+
+        # Grant change permission
+        obj_perm = ObjectPermission(name='Test permission', actions=['change'])
+        obj_perm.save()
+        obj_perm.users.add(self.user)
+        obj_perm.object_types.add(ObjectType.objects.get_for_model(self.model))
+
+        url = self._get_detail_url(site)
+        data = {
+            'tags': [{'name': 'Alpha'}],
+            'add_tags': [{'name': 'Bravo'}],
+        }
+        response = self.client.patch(url, data, format='json', **self.header)
+        self.assertHttpStatus(response, status.HTTP_400_BAD_REQUEST)
+
+    def test_create_with_add_tags(self):
+        """
+        Create a new object using add_tags.
+        """
+        Tag.objects.bulk_create((
+            Tag(name='Alpha', slug='alpha'),
+            Tag(name='Bravo', slug='bravo'),
+        ))
+
+        obj_perm = ObjectPermission(name='Test permission', actions=['add'])
+        obj_perm.save()
+        obj_perm.users.add(self.user)
+        obj_perm.object_types.add(ObjectType.objects.get_for_model(self.model))
+
+        data = {
+            'name': 'Site 10',
+            'slug': 'site-10',
+            'add_tags': [{'name': 'Alpha'}, {'name': 'Bravo'}],
+        }
+        response = self.client.post(self._get_list_url(), data, format='json', **self.header)
+        self.assertHttpStatus(response, status.HTTP_201_CREATED)
+
+        site = Site.objects.get(pk=response.data['id'])
+        tag_names = sorted(site.tags.values_list('name', flat=True))
+        self.assertEqual(tag_names, ['Alpha', 'Bravo'])
+
+    def test_create_with_remove_tags_error(self):
+        """
+        Using remove_tags when creating a new object should raise a validation error.
+        """
+        Tag.objects.bulk_create((
+            Tag(name='Alpha', slug='alpha'),
+        ))
+
+        obj_perm = ObjectPermission(name='Test permission', actions=['add'])
+        obj_perm.save()
+        obj_perm.users.add(self.user)
+        obj_perm.object_types.add(ObjectType.objects.get_for_model(self.model))
+
+        data = {
+            'name': 'Site 10',
+            'slug': 'site-10',
+            'remove_tags': [{'name': 'Alpha'}],
+        }
+        response = self.client.post(self._get_list_url(), data, format='json', **self.header)
+        self.assertHttpStatus(response, status.HTTP_400_BAD_REQUEST)
+
+    def test_add_and_remove_same_tag_error(self):
+        """
+        Including the same tag in both add_tags and remove_tags should raise a validation error.
+        """
+        site = Site.objects.first()
+        Tag.objects.bulk_create((
+            Tag(name='Alpha', slug='alpha'),
+            Tag(name='Bravo', slug='bravo'),
+        ))
+
+        obj_perm = ObjectPermission(name='Test permission', actions=['change'])
+        obj_perm.save()
+        obj_perm.users.add(self.user)
+        obj_perm.object_types.add(ObjectType.objects.get_for_model(self.model))
+
+        url = self._get_detail_url(site)
+        data = {
+            'add_tags': [{'name': 'Alpha'}, {'name': 'Bravo'}],
+            'remove_tags': [{'name': 'Alpha'}],
+        }
+        response = self.client.patch(url, data, format='json', **self.header)
+        self.assertHttpStatus(response, status.HTTP_400_BAD_REQUEST)
+
 
 
 class LocationTest(APIViewTestCases.APIViewTestCase):
 class LocationTest(APIViewTestCases.APIViewTestCase):
     model = Location
     model = Location

+ 68 - 0
netbox/netbox/api/serializers/features.py

@@ -30,17 +30,78 @@ class TaggableModelSerializer(serializers.Serializer):
     on create() and update().
     on create() and update().
     """
     """
     tags = NestedTagSerializer(many=True, required=False)
     tags = NestedTagSerializer(many=True, required=False)
+    add_tags = NestedTagSerializer(many=True, required=False, write_only=True)
+    remove_tags = NestedTagSerializer(many=True, required=False, write_only=True)
+
+    def to_internal_value(self, data):
+        ret = super().to_internal_value(data)
+
+        # Workaround to bypass requirement to include add_tags/remove_tags in Meta.fields on every serializer
+        if type(data) is dict:
+            tag_serializer = NestedTagSerializer(many=True)
+            for field_name in ('add_tags', 'remove_tags'):
+                if field_name in data:
+                    ret[field_name] = tag_serializer.to_internal_value(data[field_name])
+
+        return ret
+
+    def validate(self, data):
+        # Skip validation for nested serializer representations (e.g. when used as a related field)
+        if type(data) is not dict:
+            return super().validate(data)
+
+        if data.get('tags') and (data.get('add_tags') or data.get('remove_tags')):
+            raise serializers.ValidationError({
+                'tags': 'Cannot specify "tags" together with "add_tags" or "remove_tags".'
+            })
+
+        if self.instance is None and data.get('remove_tags'):
+            raise serializers.ValidationError({
+                'remove_tags': 'Cannot use "remove_tags" when creating a new object.'
+            })
+
+        if data.get('add_tags') and data.get('remove_tags'):
+            add_pks = {t.pk for t in data['add_tags']}
+            remove_pks = {t.pk for t in data['remove_tags']}
+            overlap = [t for t in data['add_tags'] if t.pk in (add_pks & remove_pks)]
+            if overlap:
+                raise serializers.ValidationError({
+                    'remove_tags':
+                        f'Tags may not be present in both "add_tags" and "remove_tags": '
+                        f'{", ".join(t.name for t in overlap)}'
+                })
+
+        # Pop add_tags/remove_tags before calling super() to prevent them from being passed
+        # to the model constructor during ValidatedModelSerializer validation
+        add_tags = data.pop('add_tags', None)
+        remove_tags = data.pop('remove_tags', None)
+
+        data = super().validate(data)
+
+        # Restore for use in create()/update()
+        if add_tags is not None:
+            data['add_tags'] = add_tags
+        if remove_tags is not None:
+            data['remove_tags'] = remove_tags
+
+        return data
 
 
     def create(self, validated_data):
     def create(self, validated_data):
         tags = validated_data.pop('tags', None)
         tags = validated_data.pop('tags', None)
+        add_tags = validated_data.pop('add_tags', None)
+        validated_data.pop('remove_tags', None)
         instance = super().create(validated_data)
         instance = super().create(validated_data)
 
 
         if tags is not None:
         if tags is not None:
             return self._save_tags(instance, tags)
             return self._save_tags(instance, tags)
+        if add_tags is not None:
+            instance.tags.add(*[t.name for t in add_tags])
         return instance
         return instance
 
 
     def update(self, instance, validated_data):
     def update(self, instance, validated_data):
         tags = validated_data.pop('tags', None)
         tags = validated_data.pop('tags', None)
+        add_tags = validated_data.pop('add_tags', None)
+        remove_tags = validated_data.pop('remove_tags', None)
 
 
         # Cache tags on instance for change logging
         # Cache tags on instance for change logging
         instance._tags = tags or []
         instance._tags = tags or []
@@ -49,6 +110,13 @@ class TaggableModelSerializer(serializers.Serializer):
 
 
         if tags is not None:
         if tags is not None:
             return self._save_tags(instance, tags)
             return self._save_tags(instance, tags)
+        if add_tags is not None:
+            instance.tags.add(*[t.name for t in add_tags])
+        if remove_tags is not None:
+            instance.tags.remove(*[t.name for t in remove_tags])
+        if add_tags is not None or remove_tags is not None:
+            instance._tags = instance.tags.all()
+
         return instance
         return instance
 
 
     def _save_tags(self, instance, tags):
     def _save_tags(self, instance, tags):

+ 2 - 2
netbox/utilities/testing/api.py

@@ -286,7 +286,7 @@ class APIViewTestCases:
             self.assertEqual(self._get_queryset().count(), initial_count + len(self.create_data))
             self.assertEqual(self._get_queryset().count(), initial_count + len(self.create_data))
             for i, obj in enumerate(response.data):
             for i, obj in enumerate(response.data):
                 for field in self.create_data[i]:
                 for field in self.create_data[i]:
-                    if field == 'changelog_message':
+                    if field in ('changelog_message', 'add_tags', 'remove_tags'):
                         # Write-only field
                         # Write-only field
                         continue
                         continue
                     if field not in self.validation_excluded_fields:
                     if field not in self.validation_excluded_fields:
@@ -444,7 +444,7 @@ class APIViewTestCases:
             self.assertHttpStatus(response, status.HTTP_200_OK)
             self.assertHttpStatus(response, status.HTTP_200_OK)
             for i, obj in enumerate(response.data):
             for i, obj in enumerate(response.data):
                 for field in self.bulk_update_data:
                 for field in self.bulk_update_data:
-                    if field == 'changelog_data':
+                    if field in ('changelog_message', 'add_tags', 'remove_tags'):
                         # Write-only field
                         # Write-only field
                         continue
                         continue
                     self.assertIn(field, obj, f"Bulk update field '{field}' missing from object {i} in response")
                     self.assertIn(field, obj, f"Bulk update field '{field}' missing from object {i} in response")