Просмотр исходного кода

Closes #10851: New staging mechanism (#10890)

* WIP

* Convert checkout() context manager to a class

* Misc cleanup

* Drop unique constraint from Change model

* Extend staging tests

* Misc cleanup

* Incorporate M2M changes

* Don't cancel wipe out creation records when an object is deleted

* Rename Change to StagedChange

* Add documentation for change staging
Jeremy Stretch 3 лет назад
Родитель
Сommit
a5308ea28e

+ 13 - 0
docs/models/extras/branch.md

@@ -0,0 +1,13 @@
+# Branches
+
+A branch is a collection of related [staged changes](./stagedchange.md) that have been prepared for merging into the active database. A branch can be mered by executing its `commit()` method. Deleting a branch will delete all its related changes.
+
+## Fields
+
+### Name
+
+The branch's name.
+
+### User
+
+The user to which the branch belongs (optional).

+ 26 - 0
docs/models/extras/stagedchange.md

@@ -0,0 +1,26 @@
+# Staged Changes
+
+A staged change represents the creation of a new object or the modification or deletion of an existing object to be performed at some future point. Each change must be assigned to a [branch](./branch.md).
+
+Changes can be applied individually via the `apply()` method, however it is recommended to apply changes in bulk using the parent branch's `commit()` method.
+
+## Fields
+
+!!! warning
+    Staged changes are not typically created or manipulated directly, but rather effected through the use of the [`checkout()`](../../plugins/development/staged-changes.md) context manager.
+
+### Branch
+
+The [branch](./branch.md) to which this change belongs.
+
+### Action
+
+The type of action this change represents: `create`, `update`, or `delete`.
+
+### Object
+
+A generic foreign key referencing the existing object to which this change applies.
+
+### Data
+
+JSON representation of the changes being made to the object (not applicable for deletions).

+ 42 - 0
docs/plugins/development/staged-changes.md

@@ -0,0 +1,42 @@
+# Staged Changes
+
+!!! danger "Experimental Feature"
+    This feature is still under active development and considered experimental in nature. Its use in production is strongly discouraged at this time.
+
+!!! note
+    This feature was introduced in NetBox v3.4.
+
+NetBox provides a programmatic API to stage the creation, modification, and deletion of objects without actually committing those changes to the active database. This can be useful for performing a "dry run" of bulk operations, or preparing a set of changes for administrative approval, for example.
+
+To begin staging changes, first create a [branch](../../models/extras/branch.md):
+
+```python
+from extras.models import Branch
+
+branch1 = Branch.objects.create(name='branch1')
+```
+
+Then, activate the branch using the `checkout()` context manager and begin making your changes. This initiates a new database transaction.
+
+```python
+from extras.models import Branch
+from netbox.staging import checkout
+
+branch1 = Branch.objects.get(name='branch1')
+with checkout(branch1):
+    Site.objects.create(name='New Site', slug='new-site')
+    # ...
+```
+
+Upon exiting the context, the database transaction is automatically rolled back and your changes recorded as [staged changes](../../models/extras/stagedchange.md). Re-entering a branch will trigger a new database transaction and automatically apply any staged changes associated with the branch.
+
+To apply the changes within a branch, call the branch's `commit()` method:
+
+```python
+from extras.models import Branch
+
+branch1 = Branch.objects.get(name='branch1')
+branch1.commit()
+```
+
+Committing a branch is an all-or-none operation: Any exceptions will revert the entire set of changes. After successfully committing a branch, all its associated StagedChange objects are automatically deleted (however the branch itself will remain and can be reused).

+ 3 - 0
mkdocs.yml

@@ -131,6 +131,7 @@ nav:
             - REST API: 'plugins/development/rest-api.md'
             - REST API: 'plugins/development/rest-api.md'
             - GraphQL API: 'plugins/development/graphql-api.md'
             - GraphQL API: 'plugins/development/graphql-api.md'
             - Background Tasks: 'plugins/development/background-tasks.md'
             - Background Tasks: 'plugins/development/background-tasks.md'
+            - Staged Changes: 'plugins/development/staged-changes.md'
             - Exceptions: 'plugins/development/exceptions.md'
             - Exceptions: 'plugins/development/exceptions.md'
             - Search: 'plugins/development/search.md'
             - Search: 'plugins/development/search.md'
     - Administration:
     - Administration:
@@ -191,12 +192,14 @@ nav:
             - SiteGroup: 'models/dcim/sitegroup.md'
             - SiteGroup: 'models/dcim/sitegroup.md'
             - VirtualChassis: 'models/dcim/virtualchassis.md'
             - VirtualChassis: 'models/dcim/virtualchassis.md'
         - Extras:
         - Extras:
+            - Branch: 'models/extras/branch.md'
             - ConfigContext: 'models/extras/configcontext.md'
             - ConfigContext: 'models/extras/configcontext.md'
             - CustomField: 'models/extras/customfield.md'
             - CustomField: 'models/extras/customfield.md'
             - CustomLink: 'models/extras/customlink.md'
             - CustomLink: 'models/extras/customlink.md'
             - ExportTemplate: 'models/extras/exporttemplate.md'
             - ExportTemplate: 'models/extras/exporttemplate.md'
             - ImageAttachment: 'models/extras/imageattachment.md'
             - ImageAttachment: 'models/extras/imageattachment.md'
             - JournalEntry: 'models/extras/journalentry.md'
             - JournalEntry: 'models/extras/journalentry.md'
+            - StagedChange: 'models/extras/stagedchange.md'
             - Tag: 'models/extras/tag.md'
             - Tag: 'models/extras/tag.md'
             - Webhook: 'models/extras/webhook.md'
             - Webhook: 'models/extras/webhook.md'
         - IPAM:
         - IPAM:

+ 17 - 0
netbox/extras/choices.py

@@ -182,3 +182,20 @@ class WebhookHttpMethodChoices(ChoiceSet):
         (METHOD_PATCH, 'PATCH'),
         (METHOD_PATCH, 'PATCH'),
         (METHOD_DELETE, 'DELETE'),
         (METHOD_DELETE, 'DELETE'),
     )
     )
+
+
+#
+# Staging
+#
+
+class ChangeActionChoices(ChoiceSet):
+
+    ACTION_CREATE = 'create'
+    ACTION_UPDATE = 'update'
+    ACTION_DELETE = 'delete'
+
+    CHOICES = (
+        (ACTION_CREATE, 'Create'),
+        (ACTION_UPDATE, 'Update'),
+        (ACTION_DELETE, 'Delete'),
+    )

+ 45 - 0
netbox/extras/migrations/0084_staging.py

@@ -0,0 +1,45 @@
+from django.conf import settings
+from django.db import migrations, models
+import django.db.models.deletion
+
+
+class Migration(migrations.Migration):
+
+    dependencies = [
+        ('contenttypes', '0002_remove_content_type_name'),
+        migrations.swappable_dependency(settings.AUTH_USER_MODEL),
+        ('extras', '0083_savedfilter'),
+    ]
+
+    operations = [
+        migrations.CreateModel(
+            name='Branch',
+            fields=[
+                ('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False)),
+                ('created', models.DateTimeField(auto_now_add=True, null=True)),
+                ('last_updated', models.DateTimeField(auto_now=True, null=True)),
+                ('name', models.CharField(max_length=100, unique=True)),
+                ('description', models.CharField(blank=True, max_length=200)),
+                ('user', models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.SET_NULL, to=settings.AUTH_USER_MODEL)),
+            ],
+            options={
+                'ordering': ('name',),
+            },
+        ),
+        migrations.CreateModel(
+            name='StagedChange',
+            fields=[
+                ('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False)),
+                ('created', models.DateTimeField(auto_now_add=True, null=True)),
+                ('last_updated', models.DateTimeField(auto_now=True, null=True)),
+                ('action', models.CharField(max_length=20)),
+                ('object_id', models.PositiveBigIntegerField(blank=True, null=True)),
+                ('data', models.JSONField(blank=True, null=True)),
+                ('branch', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='staged_changes', to='extras.branch')),
+                ('object_type', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='+', to='contenttypes.contenttype')),
+            ],
+            options={
+                'ordering': ('pk',),
+            },
+        ),
+    ]

+ 3 - 0
netbox/extras/models/__init__.py

@@ -3,9 +3,11 @@ from .configcontexts import ConfigContext, ConfigContextModel
 from .customfields import CustomField
 from .customfields import CustomField
 from .models import *
 from .models import *
 from .search import *
 from .search import *
+from .staging import *
 from .tags import Tag, TaggedItem
 from .tags import Tag, TaggedItem
 
 
 __all__ = (
 __all__ = (
+    'Branch',
     'CachedValue',
     'CachedValue',
     'ConfigContext',
     'ConfigContext',
     'ConfigContextModel',
     'ConfigContextModel',
@@ -20,6 +22,7 @@ __all__ = (
     'Report',
     'Report',
     'SavedFilter',
     'SavedFilter',
     'Script',
     'Script',
+    'StagedChange',
     'Tag',
     'Tag',
     'TaggedItem',
     'TaggedItem',
     'Webhook',
     'Webhook',

+ 114 - 0
netbox/extras/models/staging.py

@@ -0,0 +1,114 @@
+import logging
+
+from django.contrib.auth import get_user_model
+from django.contrib.contenttypes.fields import GenericForeignKey
+from django.contrib.contenttypes.models import ContentType
+from django.db import models, transaction
+
+from extras.choices import ChangeActionChoices
+from netbox.models import ChangeLoggedModel
+from utilities.utils import deserialize_object
+
+__all__ = (
+    'Branch',
+    'StagedChange',
+)
+
+logger = logging.getLogger('netbox.staging')
+
+
+class Branch(ChangeLoggedModel):
+    """
+    A collection of related StagedChanges.
+    """
+    name = models.CharField(
+        max_length=100,
+        unique=True
+    )
+    description = models.CharField(
+        max_length=200,
+        blank=True
+    )
+    user = models.ForeignKey(
+        to=get_user_model(),
+        on_delete=models.SET_NULL,
+        blank=True,
+        null=True
+    )
+
+    class Meta:
+        ordering = ('name',)
+
+    def __str__(self):
+        return f'{self.name} ({self.pk})'
+
+    def merge(self):
+        logger.info(f'Merging changes in branch {self}')
+        with transaction.atomic():
+            for change in self.staged_changes.all():
+                change.apply()
+        self.staged_changes.all().delete()
+
+
+class StagedChange(ChangeLoggedModel):
+    """
+    The prepared creation, modification, or deletion of an object to be applied to the active database at a
+    future point.
+    """
+    branch = models.ForeignKey(
+        to=Branch,
+        on_delete=models.CASCADE,
+        related_name='staged_changes'
+    )
+    action = models.CharField(
+        max_length=20,
+        choices=ChangeActionChoices
+    )
+    object_type = models.ForeignKey(
+        to=ContentType,
+        on_delete=models.CASCADE,
+        related_name='+'
+    )
+    object_id = models.PositiveBigIntegerField(
+        blank=True,
+        null=True
+    )
+    object = GenericForeignKey(
+        ct_field='object_type',
+        fk_field='object_id'
+    )
+    data = models.JSONField(
+        blank=True,
+        null=True
+    )
+
+    class Meta:
+        ordering = ('pk',)
+
+    def __str__(self):
+        action = self.get_action_display()
+        app_label, model_name = self.object_type.natural_key()
+        return f"{action} {app_label}.{model_name} ({self.object_id})"
+
+    @property
+    def model(self):
+        return self.object_type.model_class()
+
+    def apply(self):
+        """
+        Apply the staged create/update/delete action to the database.
+        """
+        if self.action == ChangeActionChoices.ACTION_CREATE:
+            instance = deserialize_object(self.model, self.data, pk=self.object_id)
+            logger.info(f'Creating {self.model._meta.verbose_name} {instance}')
+            instance.save()
+
+        if self.action == ChangeActionChoices.ACTION_UPDATE:
+            instance = deserialize_object(self.model, self.data, pk=self.object_id)
+            logger.info(f'Updating {self.model._meta.verbose_name} {instance}')
+            instance.save()
+
+        if self.action == ChangeActionChoices.ACTION_DELETE:
+            instance = self.model.objects.get(pk=self.object_id)
+            logger.info(f'Deleting {self.model._meta.verbose_name} {instance}')
+            instance.delete()

+ 148 - 0
netbox/netbox/staging.py

@@ -0,0 +1,148 @@
+import logging
+
+from django.contrib.contenttypes.models import ContentType
+from django.db import transaction
+from django.db.models.signals import m2m_changed, pre_delete, post_save
+
+from extras.choices import ChangeActionChoices
+from extras.models import StagedChange
+from utilities.utils import serialize_object
+
+logger = logging.getLogger('netbox.staging')
+
+
+class checkout:
+    """
+    Context manager for staging changes to NetBox objects. Staged changes are saved out-of-band
+    (as Change instances) for application at a later time, without modifying the production
+    database.
+
+        branch = Branch.objects.create(name='my-branch')
+        with checkout(branch):
+            # All changes made herein will be rolled back and stored for later
+
+    Note that invoking the context disabled transaction autocommit to facilitate manual rollbacks,
+    and restores its original value upon exit.
+    """
+    def __init__(self, branch):
+        self.branch = branch
+        self.queue = {}
+
+    def __enter__(self):
+
+        # Disable autocommit to effect a new transaction
+        logger.debug(f"Entering transaction for {self.branch}")
+        self._autocommit = transaction.get_autocommit()
+        transaction.set_autocommit(False)
+
+        # Apply any existing Changes assigned to this Branch
+        staged_changes = self.branch.staged_changes.all()
+        if change_count := staged_changes.count():
+            logger.debug(f"Applying {change_count} pre-staged changes...")
+            for change in staged_changes:
+                change.apply()
+        else:
+            logger.debug("No pre-staged changes found")
+
+        # Connect signal handlers
+        logger.debug("Connecting signal handlers")
+        post_save.connect(self.post_save_handler)
+        m2m_changed.connect(self.post_save_handler)
+        pre_delete.connect(self.pre_delete_handler)
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+
+        # Disconnect signal handlers
+        logger.debug("Disconnecting signal handlers")
+        post_save.disconnect(self.post_save_handler)
+        m2m_changed.disconnect(self.post_save_handler)
+        pre_delete.disconnect(self.pre_delete_handler)
+
+        # Roll back the transaction to return the database to its original state
+        logger.debug("Rolling back database transaction")
+        transaction.rollback()
+        logger.debug(f"Restoring autocommit state ({self._autocommit})")
+        transaction.set_autocommit(self._autocommit)
+
+        # Process queued changes
+        self.process_queue()
+
+    #
+    # Queuing
+    #
+
+    @staticmethod
+    def get_key_for_instance(instance):
+        return ContentType.objects.get_for_model(instance), instance.pk
+
+    def process_queue(self):
+        """
+        Create Change instances for all actions stored in the queue.
+        """
+        if not self.queue:
+            logger.debug(f"No queued changes; aborting")
+            return
+        logger.debug(f"Processing {len(self.queue)} queued changes")
+
+        # Iterate through the in-memory queue, creating Change instances
+        changes = []
+        for key, change in self.queue.items():
+            logger.debug(f'  {key}: {change}')
+            object_type, pk = key
+            action, data = change
+
+            changes.append(StagedChange(
+                branch=self.branch,
+                action=action,
+                object_type=object_type,
+                object_id=pk,
+                data=data
+            ))
+
+        # Save all Change instances to the database
+        StagedChange.objects.bulk_create(changes)
+
+    #
+    # Signal handlers
+    #
+
+    def post_save_handler(self, sender, instance, **kwargs):
+        """
+        Hooks to the post_save signal when a branch is active to queue create and update actions.
+        """
+        key = self.get_key_for_instance(instance)
+        object_type = instance._meta.verbose_name
+
+        # Creating a new object
+        if kwargs.get('created'):
+            logger.debug(f"[{self.branch}] Staging creation of {object_type} {instance} (PK: {instance.pk})")
+            data = serialize_object(instance, resolve_tags=False)
+            self.queue[key] = (ChangeActionChoices.ACTION_CREATE, data)
+            return
+
+        # Ignore pre_* many-to-many actions
+        if 'action' in kwargs and kwargs['action'] not in ('post_add', 'post_remove', 'post_clear'):
+            return
+
+        # Object has already been created/updated in the queue; update its queued representation
+        if key in self.queue:
+            logger.debug(f"[{self.branch}] Updating staged value for {object_type} {instance} (PK: {instance.pk})")
+            data = serialize_object(instance, resolve_tags=False)
+            self.queue[key] = (self.queue[key][0], data)
+            return
+
+        # Modifying an existing object for the first time
+        logger.debug(f"[{self.branch}] Staging changes to {object_type} {instance} (PK: {instance.pk})")
+        data = serialize_object(instance, resolve_tags=False)
+        self.queue[key] = (ChangeActionChoices.ACTION_UPDATE, data)
+
+    def pre_delete_handler(self, sender, instance, **kwargs):
+        """
+        Hooks to the pre_delete signal when a branch is active to queue delete actions.
+        """
+        key = self.get_key_for_instance(instance)
+        object_type = instance._meta.verbose_name
+
+        # Delete an existing object
+        logger.debug(f"[{self.branch}] Staging deletion of {object_type} {instance} (PK: {instance.pk})")
+        self.queue[key] = (ChangeActionChoices.ACTION_DELETE, None)

+ 210 - 0
netbox/netbox/tests/test_staging.py

@@ -0,0 +1,210 @@
+from django.test import TransactionTestCase
+
+from circuits.models import Provider, Circuit, CircuitType
+from extras.choices import ChangeActionChoices
+from extras.models import Branch, StagedChange, Tag
+from ipam.models import ASN, RIR
+from netbox.staging import checkout
+from utilities.testing import create_tags
+
+
+class StagingTestCase(TransactionTestCase):
+
+    def setUp(self):
+        create_tags('Alpha', 'Bravo', 'Charlie')
+
+        rir = RIR.objects.create(name='RIR 1', slug='rir-1')
+        asns = (
+            ASN(asn=65001, rir=rir),
+            ASN(asn=65002, rir=rir),
+            ASN(asn=65003, rir=rir),
+        )
+        ASN.objects.bulk_create(asns)
+
+        providers = (
+            Provider(name='Provider A', slug='provider-a'),
+            Provider(name='Provider B', slug='provider-b'),
+            Provider(name='Provider C', slug='provider-c'),
+        )
+        Provider.objects.bulk_create(providers)
+
+        circuit_type = CircuitType.objects.create(name='Circuit Type 1', slug='circuit-type-1')
+
+        Circuit.objects.bulk_create((
+            Circuit(provider=providers[0], cid='Circuit A1', type=circuit_type),
+            Circuit(provider=providers[0], cid='Circuit A2', type=circuit_type),
+            Circuit(provider=providers[0], cid='Circuit A3', type=circuit_type),
+            Circuit(provider=providers[1], cid='Circuit B1', type=circuit_type),
+            Circuit(provider=providers[1], cid='Circuit B2', type=circuit_type),
+            Circuit(provider=providers[1], cid='Circuit B3', type=circuit_type),
+            Circuit(provider=providers[2], cid='Circuit C1', type=circuit_type),
+            Circuit(provider=providers[2], cid='Circuit C2', type=circuit_type),
+            Circuit(provider=providers[2], cid='Circuit C3', type=circuit_type),
+        ))
+
+    def test_object_creation(self):
+        branch = Branch.objects.create(name='Branch 1')
+        tags = Tag.objects.all()
+        asns = ASN.objects.all()
+
+        with checkout(branch):
+            provider = Provider.objects.create(name='Provider D', slug='provider-d')
+            provider.asns.set(asns)
+            circuit = Circuit.objects.create(provider=provider, cid='Circuit D1', type=CircuitType.objects.first())
+            circuit.tags.set(tags)
+
+            # Sanity-checking
+            self.assertEqual(Provider.objects.count(), 4)
+            self.assertListEqual(list(provider.asns.all()), list(asns))
+            self.assertEqual(Circuit.objects.count(), 10)
+            self.assertListEqual(list(circuit.tags.all()), list(tags))
+
+        # Verify that changes have been rolled back after exiting the context
+        self.assertEqual(Provider.objects.count(), 3)
+        self.assertEqual(Circuit.objects.count(), 9)
+        self.assertEqual(StagedChange.objects.count(), 5)
+
+        # Verify that changes are replayed upon entering the context
+        with checkout(branch):
+            self.assertEqual(Provider.objects.count(), 4)
+            self.assertEqual(Circuit.objects.count(), 10)
+            provider = Provider.objects.get(name='Provider D')
+            self.assertListEqual(list(provider.asns.all()), list(asns))
+            circuit = Circuit.objects.get(cid='Circuit D1')
+            self.assertListEqual(list(circuit.tags.all()), list(tags))
+
+        # Verify that changes are applied and deleted upon branch merge
+        branch.merge()
+        self.assertEqual(Provider.objects.count(), 4)
+        self.assertEqual(Circuit.objects.count(), 10)
+        provider = Provider.objects.get(name='Provider D')
+        self.assertListEqual(list(provider.asns.all()), list(asns))
+        circuit = Circuit.objects.get(cid='Circuit D1')
+        self.assertListEqual(list(circuit.tags.all()), list(tags))
+        self.assertEqual(StagedChange.objects.count(), 0)
+
+    def test_object_modification(self):
+        branch = Branch.objects.create(name='Branch 1')
+        tags = Tag.objects.all()
+        asns = ASN.objects.all()
+
+        with checkout(branch):
+            provider = Provider.objects.get(name='Provider A')
+            provider.name = 'Provider X'
+            provider.save()
+            provider.asns.set(asns)
+            circuit = Circuit.objects.get(cid='Circuit A1')
+            circuit.cid = 'Circuit X'
+            circuit.save()
+            circuit.tags.set(tags)
+
+            # Sanity-checking
+            self.assertEqual(Provider.objects.count(), 3)
+            self.assertEqual(Provider.objects.get(pk=provider.pk).name, 'Provider X')
+            self.assertListEqual(list(provider.asns.all()), list(asns))
+            self.assertEqual(Circuit.objects.count(), 9)
+            self.assertEqual(Circuit.objects.get(pk=circuit.pk).cid, 'Circuit X')
+            self.assertListEqual(list(circuit.tags.all()), list(tags))
+
+        # Verify that changes have been rolled back after exiting the context
+        self.assertEqual(Provider.objects.count(), 3)
+        self.assertEqual(Provider.objects.get(pk=provider.pk).name, 'Provider A')
+        provider = Provider.objects.get(pk=provider.pk)
+        self.assertListEqual(list(provider.asns.all()), [])
+        self.assertEqual(Circuit.objects.count(), 9)
+        circuit = Circuit.objects.get(pk=circuit.pk)
+        self.assertEqual(circuit.cid, 'Circuit A1')
+        self.assertListEqual(list(circuit.tags.all()), [])
+        self.assertEqual(StagedChange.objects.count(), 5)
+
+        # Verify that changes are replayed upon entering the context
+        with checkout(branch):
+            self.assertEqual(Provider.objects.count(), 3)
+            self.assertEqual(Provider.objects.get(pk=provider.pk).name, 'Provider X')
+            provider = Provider.objects.get(pk=provider.pk)
+            self.assertListEqual(list(provider.asns.all()), list(asns))
+            self.assertEqual(Circuit.objects.count(), 9)
+            circuit = Circuit.objects.get(pk=circuit.pk)
+            self.assertEqual(circuit.cid, 'Circuit X')
+            self.assertListEqual(list(circuit.tags.all()), list(tags))
+
+        # Verify that changes are applied and deleted upon branch merge
+        branch.merge()
+        self.assertEqual(Provider.objects.count(), 3)
+        self.assertEqual(Provider.objects.get(pk=provider.pk).name, 'Provider X')
+        provider = Provider.objects.get(pk=provider.pk)
+        self.assertListEqual(list(provider.asns.all()), list(asns))
+        self.assertEqual(Circuit.objects.count(), 9)
+        circuit = Circuit.objects.get(pk=circuit.pk)
+        self.assertEqual(circuit.cid, 'Circuit X')
+        self.assertListEqual(list(circuit.tags.all()), list(tags))
+        self.assertEqual(StagedChange.objects.count(), 0)
+
+    def test_object_deletion(self):
+        branch = Branch.objects.create(name='Branch 1')
+
+        with checkout(branch):
+            provider = Provider.objects.get(name='Provider A')
+            provider.circuits.all().delete()
+            provider.delete()
+
+            # Sanity-checking
+            self.assertEqual(Provider.objects.count(), 2)
+            self.assertEqual(Circuit.objects.count(), 6)
+
+        # Verify that changes have been rolled back after exiting the context
+        self.assertEqual(Provider.objects.count(), 3)
+        self.assertEqual(Circuit.objects.count(), 9)
+        self.assertEqual(StagedChange.objects.count(), 4)
+
+        # Verify that changes are replayed upon entering the context
+        with checkout(branch):
+            self.assertEqual(Provider.objects.count(), 2)
+            self.assertEqual(Circuit.objects.count(), 6)
+
+        # Verify that changes are applied and deleted upon branch merge
+        branch.merge()
+        self.assertEqual(Provider.objects.count(), 2)
+        self.assertEqual(Circuit.objects.count(), 6)
+        self.assertEqual(StagedChange.objects.count(), 0)
+
+    def test_exit_enter_context(self):
+        branch = Branch.objects.create(name='Branch 1')
+
+        with checkout(branch):
+
+            # Create a new object
+            provider = Provider.objects.create(name='Provider D', slug='provider-d')
+            provider.save()
+
+        # Check that a create Change was recorded
+        self.assertEqual(StagedChange.objects.count(), 1)
+        change = StagedChange.objects.first()
+        self.assertEqual(change.action, ChangeActionChoices.ACTION_CREATE)
+        self.assertEqual(change.data['name'], provider.name)
+
+        with checkout(branch):
+
+            # Update the staged object
+            provider = Provider.objects.get(name='Provider D')
+            provider.comments = 'New comments'
+            provider.save()
+
+        # Check that a second Change object has been created for the object
+        self.assertEqual(StagedChange.objects.count(), 2)
+        change = StagedChange.objects.last()
+        self.assertEqual(change.action, ChangeActionChoices.ACTION_UPDATE)
+        self.assertEqual(change.data['name'], provider.name)
+        self.assertEqual(change.data['comments'], provider.comments)
+
+        with checkout(branch):
+
+            # Delete the staged object
+            provider = Provider.objects.get(name='Provider D')
+            provider.delete()
+
+        # Check that a third Change has recorded the object's deletion
+        self.assertEqual(StagedChange.objects.count(), 3)
+        change = StagedChange.objects.last()
+        self.assertEqual(change.action, ChangeActionChoices.ACTION_DELETE)
+        self.assertIsNone(change.data)

+ 25 - 5
netbox/utilities/utils.py

@@ -6,7 +6,8 @@ from decimal import Decimal
 from itertools import count, groupby
 from itertools import count, groupby
 
 
 import bleach
 import bleach
-from django.core.serializers import serialize
+from django.contrib.contenttypes.models import ContentType
+from django.core import serializers
 from django.db.models import Count, OuterRef, Subquery
 from django.db.models import Count, OuterRef, Subquery
 from django.db.models.functions import Coalesce
 from django.db.models.functions import Coalesce
 from django.http import QueryDict
 from django.http import QueryDict
@@ -135,14 +136,14 @@ def count_related(model, field):
     return Coalesce(subquery, 0)
     return Coalesce(subquery, 0)
 
 
 
 
-def serialize_object(obj, extra=None):
+def serialize_object(obj, resolve_tags=True, extra=None):
     """
     """
     Return a generic JSON representation of an object using Django's built-in serializer. (This is used for things like
     Return a generic JSON representation of an object using Django's built-in serializer. (This is used for things like
     change logging, not the REST API.) Optionally include a dictionary to supplement the object data. A list of keys
     change logging, not the REST API.) Optionally include a dictionary to supplement the object data. A list of keys
     can be provided to exclude them from the returned dictionary. Private fields (prefaced with an underscore) are
     can be provided to exclude them from the returned dictionary. Private fields (prefaced with an underscore) are
     implicitly excluded.
     implicitly excluded.
     """
     """
-    json_str = serialize('json', [obj])
+    json_str = serializers.serialize('json', [obj])
     data = json.loads(json_str)[0]['fields']
     data = json.loads(json_str)[0]['fields']
 
 
     # Exclude any MPTTModel fields
     # Exclude any MPTTModel fields
@@ -154,8 +155,9 @@ def serialize_object(obj, extra=None):
     if hasattr(obj, 'custom_field_data'):
     if hasattr(obj, 'custom_field_data'):
         data['custom_fields'] = data.pop('custom_field_data')
         data['custom_fields'] = data.pop('custom_field_data')
 
 
-    # Include any tags. Check for tags cached on the instance; fall back to using the manager.
-    if is_taggable(obj):
+    # Resolve any assigned tags to their names. Check for tags cached on the instance;
+    # fall back to using the manager.
+    if resolve_tags and is_taggable(obj):
         tags = getattr(obj, '_tags', None) or obj.tags.all()
         tags = getattr(obj, '_tags', None) or obj.tags.all()
         data['tags'] = sorted([tag.name for tag in tags])
         data['tags'] = sorted([tag.name for tag in tags])
 
 
@@ -172,6 +174,24 @@ def serialize_object(obj, extra=None):
     return data
     return data
 
 
 
 
+def deserialize_object(model, fields, pk=None):
+    """
+    Instantiate an object from the given model and field data. Functions as
+    the complement to serialize_object().
+    """
+    content_type = ContentType.objects.get_for_model(model)
+    if 'custom_fields' in fields:
+        fields['custom_field_data'] = fields.pop('custom_fields')
+    data = {
+        'model': '.'.join(content_type.natural_key()),
+        'pk': pk,
+        'fields': fields,
+    }
+    instance = list(serializers.deserialize('python', [data]))[0]
+
+    return instance
+
+
 def dict_to_filter_params(d, prefix=''):
 def dict_to_filter_params(d, prefix=''):
     """
     """
     Translate a dictionary of attributes to a nested set of parameters suitable for QuerySet filtering. For example:
     Translate a dictionary of attributes to a nested set of parameters suitable for QuerySet filtering. For example: