Просмотр исходного кода

Merge pull request #21806 from netbox-community/21771-rest-api-add-remove-tags

Closes #21771: Add `add_tags` & `remove_tags` fields for taggable objects
bctiemann 3 дней назад
Родитель
Сommit
1277bb6138

+ 219 - 2
netbox/dcim/tests/test_api.py

@@ -5,16 +5,17 @@ from django.urls import reverse
 from django.utils.translation import gettext as _
 from rest_framework import status
 
+from core.models import ObjectType
 from dcim.choices import *
 from dcim.constants import *
 from dcim.models import *
-from extras.models import ConfigTemplate
+from extras.models import ConfigTemplate, Tag
 from ipam.choices import VLANQinQRoleChoices
 from ipam.models import ASN, RIR, VLAN, VRF
 from netbox.api.serializers import GenericObjectSerializer
 from tenancy.models import Tenant
 from users.constants import TOKEN_PREFIX
-from users.models import Token, User
+from users.models import ObjectPermission, Token, User
 from utilities.testing import APITestCase, APIViewTestCases, create_test_device, disable_logging
 from virtualization.models import Cluster, ClusterType
 from wireless.choices import WirelessChannelChoices
@@ -195,6 +196,222 @@ class SiteTest(APIViewTestCases.APIViewTestCase):
             },
         ]
 
+    def test_add_tags(self):
+        """
+        Add tags to an existing object via the add_tags field.
+        """
+        site = Site.objects.first()
+        tags = Tag.objects.bulk_create((
+            Tag(name='Alpha', slug='alpha'),
+            Tag(name='Bravo', slug='bravo'),
+            Tag(name='Charlie', slug='charlie'),
+        ))
+        site.tags.set([tags[0], tags[1]])
+
+        # Grant change permission
+        obj_perm = ObjectPermission(name='Test permission', actions=['change'])
+        obj_perm.save()
+        obj_perm.users.add(self.user)
+        obj_perm.object_types.add(ObjectType.objects.get_for_model(self.model))
+
+        url = self._get_detail_url(site)
+        data = {
+            'add_tags': [{'name': 'Charlie'}],
+        }
+        response = self.client.patch(url, data, format='json', **self.header)
+        self.assertHttpStatus(response, status.HTTP_200_OK)
+
+        # Verify all three tags are now assigned
+        tag_names = sorted(site.tags.values_list('name', flat=True))
+        self.assertEqual(tag_names, ['Alpha', 'Bravo', 'Charlie'])
+
+        # Verify add_tags and remove_tags are not in the response
+        self.assertNotIn('add_tags', response.data)
+        self.assertNotIn('remove_tags', response.data)
+        self.assertIn('tags', response.data)
+
+    def test_remove_tags(self):
+        """
+        Remove tags from an existing object via the remove_tags field.
+        """
+        site = Site.objects.first()
+        tags = Tag.objects.bulk_create((
+            Tag(name='Alpha', slug='alpha'),
+            Tag(name='Bravo', slug='bravo'),
+            Tag(name='Charlie', slug='charlie'),
+        ))
+        site.tags.set(tags)
+
+        # Grant change permission
+        obj_perm = ObjectPermission(name='Test permission', actions=['change'])
+        obj_perm.save()
+        obj_perm.users.add(self.user)
+        obj_perm.object_types.add(ObjectType.objects.get_for_model(self.model))
+
+        url = self._get_detail_url(site)
+        data = {
+            'remove_tags': [{'name': 'Charlie'}],
+        }
+        response = self.client.patch(url, data, format='json', **self.header)
+        self.assertHttpStatus(response, status.HTTP_200_OK)
+
+        # Verify only Alpha and Bravo remain
+        tag_names = sorted(site.tags.values_list('name', flat=True))
+        self.assertEqual(tag_names, ['Alpha', 'Bravo'])
+
+    def test_remove_tags_not_assigned(self):
+        """
+        Removing a tag that is not assigned should not raise an error.
+        """
+        site = Site.objects.first()
+        tags = Tag.objects.bulk_create((
+            Tag(name='Alpha', slug='alpha'),
+            Tag(name='Bravo', slug='bravo'),
+            Tag(name='Charlie', slug='charlie'),
+        ))
+        site.tags.set([tags[0], tags[1]])
+
+        # Grant change permission
+        obj_perm = ObjectPermission(name='Test permission', actions=['change'])
+        obj_perm.save()
+        obj_perm.users.add(self.user)
+        obj_perm.object_types.add(ObjectType.objects.get_for_model(self.model))
+
+        url = self._get_detail_url(site)
+        data = {
+            'remove_tags': [{'name': 'Charlie'}],
+        }
+        response = self.client.patch(url, data, format='json', **self.header)
+        self.assertHttpStatus(response, status.HTTP_200_OK)
+
+        # Tags should be unchanged
+        tag_names = sorted(site.tags.values_list('name', flat=True))
+        self.assertEqual(tag_names, ['Alpha', 'Bravo'])
+
+    def test_add_and_remove_tags(self):
+        """
+        Add and remove tags in the same request.
+        """
+        site = Site.objects.first()
+        tags = Tag.objects.bulk_create((
+            Tag(name='Alpha', slug='alpha'),
+            Tag(name='Bravo', slug='bravo'),
+            Tag(name='Charlie', slug='charlie'),
+        ))
+        site.tags.set([tags[0], tags[1]])
+
+        # Grant change permission
+        obj_perm = ObjectPermission(name='Test permission', actions=['change'])
+        obj_perm.save()
+        obj_perm.users.add(self.user)
+        obj_perm.object_types.add(ObjectType.objects.get_for_model(self.model))
+
+        url = self._get_detail_url(site)
+        data = {
+            'add_tags': [{'name': 'Charlie'}],
+            'remove_tags': [{'name': 'Alpha'}],
+        }
+        response = self.client.patch(url, data, format='json', **self.header)
+        self.assertHttpStatus(response, status.HTTP_200_OK)
+
+        # Verify Bravo and Charlie remain
+        tag_names = sorted(site.tags.values_list('name', flat=True))
+        self.assertEqual(tag_names, ['Bravo', 'Charlie'])
+
+    def test_tags_with_add_tags_error(self):
+        """
+        Specifying tags together with add_tags or remove_tags should raise a validation error.
+        """
+        site = Site.objects.first()
+        Tag.objects.bulk_create((
+            Tag(name='Alpha', slug='alpha'),
+            Tag(name='Bravo', slug='bravo'),
+        ))
+
+        # Grant change permission
+        obj_perm = ObjectPermission(name='Test permission', actions=['change'])
+        obj_perm.save()
+        obj_perm.users.add(self.user)
+        obj_perm.object_types.add(ObjectType.objects.get_for_model(self.model))
+
+        url = self._get_detail_url(site)
+        data = {
+            'tags': [{'name': 'Alpha'}],
+            'add_tags': [{'name': 'Bravo'}],
+        }
+        response = self.client.patch(url, data, format='json', **self.header)
+        self.assertHttpStatus(response, status.HTTP_400_BAD_REQUEST)
+
+    def test_create_with_add_tags(self):
+        """
+        Create a new object using add_tags.
+        """
+        Tag.objects.bulk_create((
+            Tag(name='Alpha', slug='alpha'),
+            Tag(name='Bravo', slug='bravo'),
+        ))
+
+        obj_perm = ObjectPermission(name='Test permission', actions=['add'])
+        obj_perm.save()
+        obj_perm.users.add(self.user)
+        obj_perm.object_types.add(ObjectType.objects.get_for_model(self.model))
+
+        data = {
+            'name': 'Site 10',
+            'slug': 'site-10',
+            'add_tags': [{'name': 'Alpha'}, {'name': 'Bravo'}],
+        }
+        response = self.client.post(self._get_list_url(), data, format='json', **self.header)
+        self.assertHttpStatus(response, status.HTTP_201_CREATED)
+
+        site = Site.objects.get(pk=response.data['id'])
+        tag_names = sorted(site.tags.values_list('name', flat=True))
+        self.assertEqual(tag_names, ['Alpha', 'Bravo'])
+
+    def test_create_with_remove_tags_error(self):
+        """
+        Using remove_tags when creating a new object should raise a validation error.
+        """
+        Tag.objects.bulk_create((
+            Tag(name='Alpha', slug='alpha'),
+        ))
+
+        obj_perm = ObjectPermission(name='Test permission', actions=['add'])
+        obj_perm.save()
+        obj_perm.users.add(self.user)
+        obj_perm.object_types.add(ObjectType.objects.get_for_model(self.model))
+
+        data = {
+            'name': 'Site 10',
+            'slug': 'site-10',
+            'remove_tags': [{'name': 'Alpha'}],
+        }
+        response = self.client.post(self._get_list_url(), data, format='json', **self.header)
+        self.assertHttpStatus(response, status.HTTP_400_BAD_REQUEST)
+
+    def test_add_and_remove_same_tag_error(self):
+        """
+        Including the same tag in both add_tags and remove_tags should raise a validation error.
+        """
+        site = Site.objects.first()
+        Tag.objects.bulk_create((
+            Tag(name='Alpha', slug='alpha'),
+            Tag(name='Bravo', slug='bravo'),
+        ))
+
+        obj_perm = ObjectPermission(name='Test permission', actions=['change'])
+        obj_perm.save()
+        obj_perm.users.add(self.user)
+        obj_perm.object_types.add(ObjectType.objects.get_for_model(self.model))
+
+        url = self._get_detail_url(site)
+        data = {
+            'add_tags': [{'name': 'Alpha'}, {'name': 'Bravo'}],
+            'remove_tags': [{'name': 'Alpha'}],
+        }
+        response = self.client.patch(url, data, format='json', **self.header)
+        self.assertHttpStatus(response, status.HTTP_400_BAD_REQUEST)
+
 
 class LocationTest(APIViewTestCases.APIViewTestCase):
     model = Location

+ 68 - 0
netbox/netbox/api/serializers/features.py

@@ -30,17 +30,78 @@ class TaggableModelSerializer(serializers.Serializer):
     on create() and update().
     """
     tags = NestedTagSerializer(many=True, required=False)
+    add_tags = NestedTagSerializer(many=True, required=False, write_only=True)
+    remove_tags = NestedTagSerializer(many=True, required=False, write_only=True)
+
+    def to_internal_value(self, data):
+        ret = super().to_internal_value(data)
+
+        # Workaround to bypass requirement to include add_tags/remove_tags in Meta.fields on every serializer
+        if type(data) is dict:
+            tag_serializer = NestedTagSerializer(many=True)
+            for field_name in ('add_tags', 'remove_tags'):
+                if field_name in data:
+                    ret[field_name] = tag_serializer.to_internal_value(data[field_name])
+
+        return ret
+
+    def validate(self, data):
+        # Skip validation for nested serializer representations (e.g. when used as a related field)
+        if type(data) is not dict:
+            return super().validate(data)
+
+        if data.get('tags') and (data.get('add_tags') or data.get('remove_tags')):
+            raise serializers.ValidationError({
+                'tags': 'Cannot specify "tags" together with "add_tags" or "remove_tags".'
+            })
+
+        if self.instance is None and data.get('remove_tags'):
+            raise serializers.ValidationError({
+                'remove_tags': 'Cannot use "remove_tags" when creating a new object.'
+            })
+
+        if data.get('add_tags') and data.get('remove_tags'):
+            add_pks = {t.pk for t in data['add_tags']}
+            remove_pks = {t.pk for t in data['remove_tags']}
+            overlap = [t for t in data['add_tags'] if t.pk in (add_pks & remove_pks)]
+            if overlap:
+                raise serializers.ValidationError({
+                    'remove_tags':
+                        f'Tags may not be present in both "add_tags" and "remove_tags": '
+                        f'{", ".join(t.name for t in overlap)}'
+                })
+
+        # Pop add_tags/remove_tags before calling super() to prevent them from being passed
+        # to the model constructor during ValidatedModelSerializer validation
+        add_tags = data.pop('add_tags', None)
+        remove_tags = data.pop('remove_tags', None)
+
+        data = super().validate(data)
+
+        # Restore for use in create()/update()
+        if add_tags is not None:
+            data['add_tags'] = add_tags
+        if remove_tags is not None:
+            data['remove_tags'] = remove_tags
+
+        return data
 
     def create(self, validated_data):
         tags = validated_data.pop('tags', None)
+        add_tags = validated_data.pop('add_tags', None)
+        validated_data.pop('remove_tags', None)
         instance = super().create(validated_data)
 
         if tags is not None:
             return self._save_tags(instance, tags)
+        if add_tags is not None:
+            instance.tags.add(*[t.name for t in add_tags])
         return instance
 
     def update(self, instance, validated_data):
         tags = validated_data.pop('tags', None)
+        add_tags = validated_data.pop('add_tags', None)
+        remove_tags = validated_data.pop('remove_tags', None)
 
         # Cache tags on instance for change logging
         instance._tags = tags or []
@@ -49,6 +110,13 @@ class TaggableModelSerializer(serializers.Serializer):
 
         if tags is not None:
             return self._save_tags(instance, tags)
+        if add_tags is not None:
+            instance.tags.add(*[t.name for t in add_tags])
+        if remove_tags is not None:
+            instance.tags.remove(*[t.name for t in remove_tags])
+        if add_tags is not None or remove_tags is not None:
+            instance._tags = instance.tags.all()
+
         return instance
 
     def _save_tags(self, instance, tags):

+ 2 - 2
netbox/utilities/testing/api.py

@@ -286,7 +286,7 @@ class APIViewTestCases:
             self.assertEqual(self._get_queryset().count(), initial_count + len(self.create_data))
             for i, obj in enumerate(response.data):
                 for field in self.create_data[i]:
-                    if field == 'changelog_message':
+                    if field in ('changelog_message', 'add_tags', 'remove_tags'):
                         # Write-only field
                         continue
                     if field not in self.validation_excluded_fields:
@@ -444,7 +444,7 @@ class APIViewTestCases:
             self.assertHttpStatus(response, status.HTTP_200_OK)
             for i, obj in enumerate(response.data):
                 for field in self.bulk_update_data:
-                    if field == 'changelog_data':
+                    if field in ('changelog_message', 'add_tags', 'remove_tags'):
                         # Write-only field
                         continue
                     self.assertIn(field, obj, f"Bulk update field '{field}' missing from object {i} in response")