Przeglądaj źródła

Add JournalEntry tests

Jeremy Stretch 4 lat temu
rodzic
commit
f2c079de87

+ 7 - 6
netbox/extras/api/serializers.py

@@ -203,12 +203,13 @@ class JournalEntrySerializer(ValidatedModelSerializer):
     def validate(self, data):
     def validate(self, data):
 
 
         # Validate that the parent object exists
         # Validate that the parent object exists
-        try:
-            data['content_type'].get_object_for_this_type(id=data['object_id'])
-        except ObjectDoesNotExist:
-            raise serializers.ValidationError(
-                "Invalid parent object: {} ID {}".format(data['content_type'], data['object_id'])
-            )
+        if 'assigned_object_type' in data and 'assigned_object_id' in data:
+            try:
+                data['assigned_object_type'].get_object_for_this_type(id=data['assigned_object_id'])
+            except ObjectDoesNotExist:
+                raise serializers.ValidationError(
+                    f"Invalid assigned_object: {data['assigned_object_type']} ID {data['assigned_object_id']}"
+                )
 
 
         # Enforce model validation
         # Enforce model validation
         super().validate(data)
         super().validate(data)

+ 51 - 0
netbox/extras/tests/test_api.py

@@ -1,6 +1,7 @@
 import datetime
 import datetime
 from unittest import skipIf
 from unittest import skipIf
 
 
+from django.contrib.auth.models import User
 from django.contrib.contenttypes.models import ContentType
 from django.contrib.contenttypes.models import ContentType
 from django.test import override_settings
 from django.test import override_settings
 from django.urls import reverse
 from django.urls import reverse
@@ -309,6 +310,56 @@ class ImageAttachmentTest(
         ImageAttachment.objects.bulk_create(image_attachments)
         ImageAttachment.objects.bulk_create(image_attachments)
 
 
 
 
+class JournalEntryTest(APIViewTestCases.APIViewTestCase):
+    model = JournalEntry
+    brief_fields = ['created', 'display', 'id', 'url']
+    bulk_update_data = {
+        'comments': 'Overwritten',
+    }
+
+    @classmethod
+    def setUpTestData(cls):
+        user = User.objects.first()
+        site = Site.objects.create(name='Site 1', slug='site-1')
+
+        journal_entries = (
+            JournalEntry(
+                created_by=user,
+                assigned_object=site,
+                comments='Fourth entry',
+            ),
+            JournalEntry(
+                created_by=user,
+                assigned_object=site,
+                comments='Fifth entry',
+            ),
+            JournalEntry(
+                created_by=user,
+                assigned_object=site,
+                comments='Sixth entry',
+            ),
+        )
+        JournalEntry.objects.bulk_create(journal_entries)
+
+        cls.create_data = [
+            {
+                'assigned_object_type': 'dcim.site',
+                'assigned_object_id': site.pk,
+                'comments': 'First entry',
+            },
+            {
+                'assigned_object_type': 'dcim.site',
+                'assigned_object_id': site.pk,
+                'comments': 'Second entry',
+            },
+            {
+                'assigned_object_type': 'dcim.site',
+                'assigned_object_id': site.pk,
+                'comments': 'Third entry',
+            },
+        ]
+
+
 class ConfigContextTest(APIViewTestCases.APIViewTestCase):
 class ConfigContextTest(APIViewTestCases.APIViewTestCase):
     model = ConfigContext
     model = ConfigContext
     brief_fields = ['display', 'id', 'name', 'url']
     brief_fields = ['display', 'id', 'name', 'url']

+ 84 - 0
netbox/extras/tests/test_filters.py

@@ -255,6 +255,90 @@ class ImageAttachmentTestCase(TestCase):
         self.assertEqual(self.filterset(params, self.queryset).qs.count(), 1)
         self.assertEqual(self.filterset(params, self.queryset).qs.count(), 1)
 
 
 
 
+class JournalEntryTestCase(TestCase):
+    queryset = JournalEntry.objects.all()
+    filterset = JournalEntryFilterSet
+
+    @classmethod
+    def setUpTestData(cls):
+        sites = (
+            Site(name='Site 1', slug='site-1'),
+            Site(name='Site 2', slug='site-2'),
+        )
+        Site.objects.bulk_create(sites)
+
+        racks = (
+            Rack(name='Rack 1', site=sites[0]),
+            Rack(name='Rack 2', site=sites[1]),
+        )
+        Rack.objects.bulk_create(racks)
+
+        users = (
+            User(username='Alice'),
+            User(username='Bob'),
+            User(username='Charlie'),
+        )
+        User.objects.bulk_create(users)
+
+        journal_entries = (
+            JournalEntry(
+                assigned_object=sites[0],
+                created_by=users[0],
+                comments='New journal entry'
+            ),
+            JournalEntry(
+                assigned_object=sites[0],
+                created_by=users[1],
+                comments='New journal entry'
+            ),
+            JournalEntry(
+                assigned_object=sites[1],
+                created_by=users[2],
+                comments='New journal entry'
+            ),
+            JournalEntry(
+                assigned_object=racks[0],
+                created_by=users[0],
+                comments='New journal entry'
+            ),
+            JournalEntry(
+                assigned_object=racks[0],
+                created_by=users[1],
+                comments='New journal entry'
+            ),
+            JournalEntry(
+                assigned_object=racks[1],
+                created_by=users[2],
+                comments='New journal entry'
+            ),
+        )
+        JournalEntry.objects.bulk_create(journal_entries)
+
+    def test_id(self):
+        params = {'id': self.queryset.values_list('pk', flat=True)[:2]}
+        self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
+
+    def test_created_by(self):
+        users = User.objects.filter(username__in=['Alice', 'Bob'])
+        params = {'created_by': [users[0].username, users[1].username]}
+        self.assertEqual(self.filterset(params, self.queryset).qs.count(), 4)
+        params = {'created_by_id': [users[0].pk, users[1].pk]}
+        self.assertEqual(self.filterset(params, self.queryset).qs.count(), 4)
+
+    def test_assigned_object_type(self):
+        params = {'assigned_object_type': 'dcim.site'}
+        self.assertEqual(self.filterset(params, self.queryset).qs.count(), 3)
+        params = {'assigned_object_type_id': ContentType.objects.get(app_label='dcim', model='site').pk}
+        self.assertEqual(self.filterset(params, self.queryset).qs.count(), 3)
+
+    def test_assigned_object(self):
+        params = {
+            'assigned_object_type': 'dcim.site',
+            'assigned_object_id': [Site.objects.first().pk],
+        }
+        self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
+
+
 class ConfigContextTestCase(TestCase):
 class ConfigContextTestCase(TestCase):
     queryset = ConfigContext.objects.all()
     queryset = ConfigContext.objects.all()
     filterset = ConfigContextFilterSet
     filterset = ConfigContextFilterSet

+ 36 - 2
netbox/extras/tests/test_views.py

@@ -3,12 +3,11 @@ import uuid
 
 
 from django.contrib.auth.models import User
 from django.contrib.auth.models import User
 from django.contrib.contenttypes.models import ContentType
 from django.contrib.contenttypes.models import ContentType
-from django.test import override_settings
 from django.urls import reverse
 from django.urls import reverse
 
 
 from dcim.models import Site
 from dcim.models import Site
 from extras.choices import ObjectChangeActionChoices
 from extras.choices import ObjectChangeActionChoices
-from extras.models import ConfigContext, CustomLink, ObjectChange, Tag
+from extras.models import ConfigContext, CustomLink, JournalEntry, ObjectChange, Tag
 from utilities.testing import ViewTestCases, TestCase
 from utilities.testing import ViewTestCases, TestCase
 
 
 
 
@@ -128,6 +127,41 @@ class ObjectChangeTestCase(TestCase):
         self.assertHttpStatus(response, 200)
         self.assertHttpStatus(response, 200)
 
 
 
 
+class JournalEntryTestCase(
+    # ViewTestCases.GetObjectViewTestCase,
+    ViewTestCases.CreateObjectViewTestCase,
+    ViewTestCases.EditObjectViewTestCase,
+    ViewTestCases.DeleteObjectViewTestCase,
+    ViewTestCases.ListObjectsViewTestCase,
+    # ViewTestCases.BulkEditObjectsViewTestCase,
+    # ViewTestCases.BulkDeleteObjectsViewTestCase
+):
+    model = JournalEntry
+
+    @classmethod
+    def setUpTestData(cls):
+        site_ct = ContentType.objects.get_for_model(Site)
+
+        site = Site.objects.create(name='Site 1', slug='site-1')
+        user = User.objects.create(username='User 1')
+
+        JournalEntry.objects.bulk_create((
+            JournalEntry(assigned_object=site, created_by=user, comments='First entry'),
+            JournalEntry(assigned_object=site, created_by=user, comments='Second entry'),
+            JournalEntry(assigned_object=site, created_by=user, comments='Third entry'),
+        ))
+
+        cls.form_data = {
+            'assigned_object_type': site_ct.pk,
+            'assigned_object_id': site.pk,
+            'comments': 'A new entry',
+        }
+
+        cls.bulk_edit_data = {
+            'comments': 'Overwritten',
+        }
+
+
 class CustomLinkTest(TestCase):
 class CustomLinkTest(TestCase):
     user_permissions = ['dcim.view_site']
     user_permissions = ['dcim.view_site']
 
 

+ 2 - 0
netbox/extras/views.py

@@ -304,6 +304,8 @@ class JournalEntryEditView(generic.ObjectEditView):
         return obj
         return obj
 
 
     def get_return_url(self, request, instance):
     def get_return_url(self, request, instance):
+        if not instance.assigned_object:
+            return reverse('extras:journalentry_list')
         obj = instance.assigned_object
         obj = instance.assigned_object
         viewname = f'{obj._meta.app_label}:{obj._meta.model_name}_journal'
         viewname = f'{obj._meta.app_label}:{obj._meta.model_name}_journal'
         return reverse(viewname, kwargs={'pk': obj.pk})
         return reverse(viewname, kwargs={'pk': obj.pk})