Просмотр исходного кода

Rewrote ObjectChangeMiddleware to remove the curried handle_deleted_object() function

Jeremy Stretch 6 лет назад
Родитель
Сommit
ccb9f7bfe2

+ 41 - 38
netbox/extras/middleware.py

@@ -1,14 +1,15 @@
 import random
 import random
 import threading
 import threading
 import uuid
 import uuid
+from copy import deepcopy
 from datetime import timedelta
 from datetime import timedelta
 
 
 from django.conf import settings
 from django.conf import settings
-from django.db.models.signals import post_delete, post_save
+from django.db.models.signals import pre_delete, post_save
 from django.utils import timezone
 from django.utils import timezone
-from django.utils.functional import curry
 from django_prometheus.models import model_deletes, model_inserts, model_updates
 from django_prometheus.models import model_deletes, model_inserts, model_updates
 
 
+from utilities.querysets import DummyQuerySet
 from .constants import *
 from .constants import *
 from .models import ObjectChange
 from .models import ObjectChange
 from .signals import purge_changelog
 from .signals import purge_changelog
@@ -19,33 +20,34 @@ _thread_locals = threading.local()
 
 
 def handle_changed_object(sender, instance, **kwargs):
 def handle_changed_object(sender, instance, **kwargs):
     """
     """
-    Fires when an object is created or updated
+    Fires when an object is created or updated.
     """
     """
-    # Queue the object and a new ObjectChange for processing once the request completes
-    if hasattr(instance, 'to_objectchange'):
-        action = OBJECTCHANGE_ACTION_CREATE if kwargs['created'] else OBJECTCHANGE_ACTION_UPDATE
-        objectchange = instance.to_objectchange(action)
-        _thread_locals.changed_objects.append(
-            (instance, objectchange)
-        )
+    # Queue a copy of the object for processing once the request completes
+    action = OBJECTCHANGE_ACTION_CREATE if kwargs['created'] else OBJECTCHANGE_ACTION_UPDATE
+    _thread_locals.changed_objects.append(
+        (instance, action)
+    )
 
 
 
 
-def _handle_deleted_object(request, sender, instance, **kwargs):
+def handle_deleted_object(sender, instance, **kwargs):
     """
     """
-    Fires when an object is deleted
+    Fires when an object is deleted.
     """
     """
-    # Record an Object Change
-    if hasattr(instance, 'to_objectchange'):
-        objectchange = instance.to_objectchange(OBJECTCHANGE_ACTION_DELETE)
-        objectchange.user = request.user
-        objectchange.request_id = request.id
-        objectchange.save()
+    # Cache custom fields prior to copying the instance
+    if hasattr(instance, 'cache_custom_fields'):
+        instance.cache_custom_fields()
 
 
-    # Enqueue webhooks
-    enqueue_webhooks(instance, request.user, request.id, OBJECTCHANGE_ACTION_DELETE)
+    # Create a copy of the object being deleted
+    copy = deepcopy(instance)
 
 
-    # Increment metric counters
-    model_deletes.labels(instance._meta.model_name).inc()
+    # Preserve tags
+    if hasattr(instance, 'tags'):
+        copy.tags = DummyQuerySet(instance.tags.all())
+
+    # Queue a copy of the object for processing once the request completes
+    _thread_locals.changed_objects.append(
+        (copy, OBJECTCHANGE_ACTION_DELETE)
+    )
 
 
 
 
 def purge_objectchange_cache(sender, **kwargs):
 def purge_objectchange_cache(sender, **kwargs):
@@ -81,12 +83,9 @@ class ObjectChangeMiddleware(object):
         # the same request.
         # the same request.
         request.id = uuid.uuid4()
         request.id = uuid.uuid4()
 
 
-        # Signals don't include the request context, so we're currying it into the post_delete function ahead of time.
-        handle_deleted_object = curry(_handle_deleted_object, request)
-
         # Connect our receivers to the post_save and post_delete signals.
         # Connect our receivers to the post_save and post_delete signals.
-        post_save.connect(handle_changed_object, dispatch_uid='cache_changed_object')
-        post_delete.connect(handle_deleted_object, dispatch_uid='cache_deleted_object')
+        post_save.connect(handle_changed_object, dispatch_uid='handle_changed_object')
+        pre_delete.connect(handle_deleted_object, dispatch_uid='handle_deleted_object')
 
 
         # Provide a hook for purging the change cache
         # Provide a hook for purging the change cache
         purge_changelog.connect(purge_objectchange_cache)
         purge_changelog.connect(purge_objectchange_cache)
@@ -98,22 +97,26 @@ class ObjectChangeMiddleware(object):
         if not _thread_locals.changed_objects:
         if not _thread_locals.changed_objects:
             return response
             return response
 
 
-        # Create records for any cached objects that were created/updated.
-        for obj, objectchange in _thread_locals.changed_objects:
+        # Create records for any cached objects that were changed.
+        for instance, action in _thread_locals.changed_objects:
 
 
-            # Record the change
-            objectchange.user = request.user
-            objectchange.request_id = request.id
-            objectchange.save()
+            # Record an ObjectChange if applicable
+            if hasattr(instance, 'to_objectchange'):
+                objectchange = instance.to_objectchange(action)
+                objectchange.user = request.user
+                objectchange.request_id = request.id
+                objectchange.save()
 
 
             # Enqueue webhooks
             # Enqueue webhooks
-            enqueue_webhooks(obj, request.user, request.id, objectchange.action)
+            enqueue_webhooks(instance, request.user, request.id, action)
 
 
             # Increment metric counters
             # Increment metric counters
-            if objectchange.action == OBJECTCHANGE_ACTION_CREATE:
-                model_inserts.labels(obj._meta.model_name).inc()
-            elif objectchange.action == OBJECTCHANGE_ACTION_UPDATE:
-                model_updates.labels(obj._meta.model_name).inc()
+            if action == OBJECTCHANGE_ACTION_CREATE:
+                model_inserts.labels(instance._meta.model_name).inc()
+            elif action == OBJECTCHANGE_ACTION_UPDATE:
+                model_updates.labels(instance._meta.model_name).inc()
+            elif action == OBJECTCHANGE_ACTION_DELETE:
+                model_deletes.labels(instance._meta.model_name).inc()
 
 
         # Housekeeping: 1% chance of clearing out expired ObjectChanges. This applies only to requests which result in
         # Housekeeping: 1% chance of clearing out expired ObjectChanges. This applies only to requests which result in
         # one or more changes being logged.
         # one or more changes being logged.

+ 9 - 4
netbox/extras/models.py

@@ -138,16 +138,21 @@ class CustomFieldModel(models.Model):
     class Meta:
     class Meta:
         abstract = True
         abstract = True
 
 
+    def cache_custom_fields(self):
+        """
+        Cache all custom field values for this instance
+        """
+        self._cf = {
+            field.name: value for field, value in self.get_custom_fields().items()
+        }
+
     @property
     @property
     def cf(self):
     def cf(self):
         """
         """
         Name-based CustomFieldValue accessor for use in templates
         Name-based CustomFieldValue accessor for use in templates
         """
         """
         if self._cf is None:
         if self._cf is None:
-            # Cache all custom field values for this instance
-            self._cf = {
-                field.name: value for field, value in self.get_custom_fields().items()
-            }
+            self.cache_custom_fields()
         return self._cf
         return self._cf
 
 
     def get_custom_fields(self):
     def get_custom_fields(self):

+ 54 - 11
netbox/extras/tests/test_changelog.py

@@ -1,33 +1,57 @@
+from django.contrib.contenttypes.models import ContentType
 from django.urls import reverse
 from django.urls import reverse
 from rest_framework import status
 from rest_framework import status
 
 
 from dcim.models import Site
 from dcim.models import Site
-from extras.constants import OBJECTCHANGE_ACTION_CREATE, OBJECTCHANGE_ACTION_UPDATE, OBJECTCHANGE_ACTION_DELETE
-from extras.models import ObjectChange
+from extras.constants import *
+from extras.models import CustomField, CustomFieldValue, ObjectChange
 from utilities.testing import APITestCase
 from utilities.testing import APITestCase
 
 
 
 
 class ChangeLogTest(APITestCase):
 class ChangeLogTest(APITestCase):
 
 
+    def setUp(self):
+
+        super().setUp()
+
+        # Create a custom field on the Site model
+        ct = ContentType.objects.get_for_model(Site)
+        cf = CustomField(
+            type=CF_TYPE_TEXT,
+            name='my_field',
+            required=False
+        )
+        cf.save()
+        cf.obj_type.set([ct])
+
     def test_create_object(self):
     def test_create_object(self):
 
 
         data = {
         data = {
             'name': 'Test Site 1',
             'name': 'Test Site 1',
             'slug': 'test-site-1',
             'slug': 'test-site-1',
+            'custom_fields': {
+                'my_field': 'ABC'
+            },
+            'tags': [
+                'bar', 'foo'
+            ],
         }
         }
 
 
         self.assertEqual(ObjectChange.objects.count(), 0)
         self.assertEqual(ObjectChange.objects.count(), 0)
 
 
         url = reverse('dcim-api:site-list')
         url = reverse('dcim-api:site-list')
         response = self.client.post(url, data, format='json', **self.header)
         response = self.client.post(url, data, format='json', **self.header)
-
         self.assertHttpStatus(response, status.HTTP_201_CREATED)
         self.assertHttpStatus(response, status.HTTP_201_CREATED)
-        self.assertEqual(ObjectChange.objects.count(), 1)
 
 
-        oc = ObjectChange.objects.first()
         site = Site.objects.get(pk=response.data['id'])
         site = Site.objects.get(pk=response.data['id'])
+        oc = ObjectChange.objects.get(
+            changed_object_type=ContentType.objects.get_for_model(Site),
+            changed_object_id=site.pk
+        )
         self.assertEqual(oc.changed_object, site)
         self.assertEqual(oc.changed_object, site)
         self.assertEqual(oc.action, OBJECTCHANGE_ACTION_CREATE)
         self.assertEqual(oc.action, OBJECTCHANGE_ACTION_CREATE)
+        self.assertEqual(oc.object_data['custom_fields'], data['custom_fields'])
+        self.assertListEqual(sorted(oc.object_data['tags']), data['tags'])
 
 
     def test_update_object(self):
     def test_update_object(self):
 
 
@@ -37,26 +61,43 @@ class ChangeLogTest(APITestCase):
         data = {
         data = {
             'name': 'Test Site X',
             'name': 'Test Site X',
             'slug': 'test-site-x',
             'slug': 'test-site-x',
+            'custom_fields': {
+                'my_field': 'DEF'
+            },
+            'tags': [
+                'abc', 'xyz'
+            ],
         }
         }
 
 
         self.assertEqual(ObjectChange.objects.count(), 0)
         self.assertEqual(ObjectChange.objects.count(), 0)
 
 
         url = reverse('dcim-api:site-detail', kwargs={'pk': site.pk})
         url = reverse('dcim-api:site-detail', kwargs={'pk': site.pk})
         response = self.client.put(url, data, format='json', **self.header)
         response = self.client.put(url, data, format='json', **self.header)
-
         self.assertHttpStatus(response, status.HTTP_200_OK)
         self.assertHttpStatus(response, status.HTTP_200_OK)
-        self.assertEqual(ObjectChange.objects.count(), 1)
-        site = Site.objects.get(pk=response.data['id'])
-        self.assertEqual(site.name, data['name'])
 
 
-        oc = ObjectChange.objects.first()
+        site = Site.objects.get(pk=response.data['id'])
+        oc = ObjectChange.objects.get(
+            changed_object_type=ContentType.objects.get_for_model(Site),
+            changed_object_id=site.pk
+        )
         self.assertEqual(oc.changed_object, site)
         self.assertEqual(oc.changed_object, site)
         self.assertEqual(oc.action, OBJECTCHANGE_ACTION_UPDATE)
         self.assertEqual(oc.action, OBJECTCHANGE_ACTION_UPDATE)
+        self.assertEqual(oc.object_data['custom_fields'], data['custom_fields'])
+        self.assertListEqual(sorted(oc.object_data['tags']), data['tags'])
 
 
     def test_delete_object(self):
     def test_delete_object(self):
 
 
-        site = Site(name='Test Site 1', slug='test-site-1')
+        site = Site(
+            name='Test Site 1',
+            slug='test-site-1'
+        )
         site.save()
         site.save()
+        site.tags.add('foo', 'bar')
+        CustomFieldValue.objects.create(
+            field=CustomField.objects.get(name='my_field'),
+            obj=site,
+            value='ABC'
+        )
 
 
         self.assertEqual(ObjectChange.objects.count(), 0)
         self.assertEqual(ObjectChange.objects.count(), 0)
 
 
@@ -70,3 +111,5 @@ class ChangeLogTest(APITestCase):
         self.assertEqual(oc.changed_object, None)
         self.assertEqual(oc.changed_object, None)
         self.assertEqual(oc.object_repr, site.name)
         self.assertEqual(oc.object_repr, site.name)
         self.assertEqual(oc.action, OBJECTCHANGE_ACTION_DELETE)
         self.assertEqual(oc.action, OBJECTCHANGE_ACTION_DELETE)
+        self.assertEqual(oc.object_data['custom_fields'], {'my_field': 'ABC'})
+        self.assertListEqual(sorted(oc.object_data['tags']), ['bar', 'foo'])

+ 9 - 0
netbox/utilities/querysets.py

@@ -0,0 +1,9 @@
+class DummyQuerySet:
+    """
+    A fake QuerySet that can be used to cache relationships to objects that have been deleted.
+    """
+    def __init__(self, queryset):
+        self._cache = [obj for obj in queryset.all()]
+
+    def all(self):
+        return self._cache

+ 1 - 1
netbox/utilities/utils.py

@@ -99,7 +99,7 @@ def serialize_object(obj, extra=None):
     # Include any custom fields
     # Include any custom fields
     if hasattr(obj, 'get_custom_fields'):
     if hasattr(obj, 'get_custom_fields'):
         data['custom_fields'] = {
         data['custom_fields'] = {
-            field.name: str(value) for field, value in obj.get_custom_fields().items()
+            field: str(value) for field, value in obj.cf.items()
         }
         }
 
 
     # Include any tags
     # Include any tags