Parcourir la source

12794 change User ref to get_user_model (#12905)

* 12794 change User ref to get_user_model

* 12794 call get_user_model once in tests

* 12794 call get_user_model once in tests

* 12794 use settings.AUTH_USER_MODEL for FK reference
Arthur Hanson il y a 2 ans
Parent
commit
518fd8cca6

+ 2 - 2
netbox/core/forms/filtersets.py

@@ -1,5 +1,5 @@
 from django import forms
-from django.contrib.auth.models import User
+from django.contrib.auth import get_user_model
 from django.contrib.contenttypes.models import ContentType
 from django.utils.translation import gettext as _
 
@@ -105,7 +105,7 @@ class JobFilterForm(SavedFiltersMixin, FilterForm):
         widget=DateTimePicker()
     )
     user = DynamicModelMultipleChoiceField(
-        queryset=User.objects.all(),
+        queryset=get_user_model().objects.all(),
         required=False,
         label=_('User'),
         widget=APISelectMultiple(

+ 2 - 2
netbox/core/management/commands/nbshell.py

@@ -5,7 +5,7 @@ import sys
 from django import get_version
 from django.apps import apps
 from django.conf import settings
-from django.contrib.auth.models import User
+from django.contrib.auth import get_user_model
 from django.contrib.contenttypes.models import ContentType
 from django.core.management.base import BaseCommand
 
@@ -60,7 +60,7 @@ class Command(BaseCommand):
 
         # Additional objects to include
         namespace['ContentType'] = ContentType
-        namespace['User'] = User
+        namespace['User'] = get_user_model()
 
         # Load convenience commands
         namespace.update({

+ 2 - 2
netbox/core/models/jobs.py

@@ -1,7 +1,7 @@
 import uuid
 
 import django_rq
-from django.contrib.auth.models import User
+from django.conf import settings
 from django.contrib.contenttypes.fields import GenericForeignKey
 from django.contrib.contenttypes.models import ContentType
 from django.core.validators import MinValueValidator
@@ -69,7 +69,7 @@ class Job(models.Model):
         blank=True
     )
     user = models.ForeignKey(
-        to=User,
+        to=settings.AUTH_USER_MODEL,
         on_delete=models.SET_NULL,
         related_name='+',
         blank=True,

+ 3 - 3
netbox/dcim/filtersets.py

@@ -1,5 +1,5 @@
 import django_filters
-from django.contrib.auth.models import User
+from django.contrib.auth import get_user_model
 from django.utils.translation import gettext as _
 
 from extras.filtersets import LocalConfigContextFilterSet
@@ -395,12 +395,12 @@ class RackReservationFilterSet(NetBoxModelFilterSet, TenancyFilterSet):
         label=_('Location (slug)'),
     )
     user_id = django_filters.ModelMultipleChoiceFilter(
-        queryset=User.objects.all(),
+        queryset=get_user_model().objects.all(),
         label=_('User (ID)'),
     )
     user = django_filters.ModelMultipleChoiceFilter(
         field_name='user__username',
-        queryset=User.objects.all(),
+        queryset=get_user_model().objects.all(),
         to_field_name='username',
         label=_('User (name)'),
     )

+ 2 - 2
netbox/dcim/forms/bulk_edit.py

@@ -1,6 +1,6 @@
 from django import forms
 from django.conf import settings
-from django.contrib.auth.models import User
+from django.contrib.auth import get_user_model
 from django.utils.translation import gettext as _
 from timezone_field import TimeZoneFormField
 
@@ -322,7 +322,7 @@ class RackBulkEditForm(NetBoxModelBulkEditForm):
 
 class RackReservationBulkEditForm(NetBoxModelBulkEditForm):
     user = forms.ModelChoiceField(
-        queryset=User.objects.order_by(
+        queryset=get_user_model().objects.order_by(
             'username'
         ),
         required=False

+ 2 - 2
netbox/dcim/forms/filtersets.py

@@ -1,5 +1,5 @@
 from django import forms
-from django.contrib.auth.models import User
+from django.contrib.auth import get_user_model
 from django.utils.translation import gettext as _
 
 from dcim.choices import *
@@ -376,7 +376,7 @@ class RackReservationFilterForm(TenancyFilterForm, NetBoxModelFilterSetForm):
         label=_('Rack')
     )
     user_id = DynamicModelMultipleChoiceField(
-        queryset=User.objects.all(),
+        queryset=get_user_model().objects.all(),
         required=False,
         label=_('User'),
         widget=APISelectMultiple(

+ 2 - 2
netbox/dcim/forms/model_forms.py

@@ -1,5 +1,5 @@
 from django import forms
-from django.contrib.auth.models import User
+from django.contrib.auth import get_user_model
 from django.contrib.contenttypes.models import ContentType
 from django.utils.translation import gettext as _
 from timezone_field import TimeZoneFormField
@@ -236,7 +236,7 @@ class RackReservationForm(TenancyForm, NetBoxModelForm):
         help_text=_("Comma-separated list of numeric unit IDs. A range may be specified using a hyphen.")
     )
     user = forms.ModelChoiceField(
-        queryset=User.objects.order_by(
+        queryset=get_user_model().objects.order_by(
             'username'
         )
     )

+ 2 - 2
netbox/dcim/models/racks.py

@@ -1,7 +1,7 @@
 import decimal
 from functools import cached_property
 
-from django.contrib.auth.models import User
+from django.conf import settings
 from django.contrib.contenttypes.fields import GenericRelation
 from django.contrib.postgres.fields import ArrayField
 from django.core.exceptions import ValidationError
@@ -505,7 +505,7 @@ class RackReservation(PrimaryModel):
         null=True
     )
     user = models.ForeignKey(
-        to=User,
+        to=settings.AUTH_USER_MODEL,
         on_delete=models.PROTECT
     )
     description = models.CharField(

+ 4 - 1
netbox/dcim/tests/test_api.py

@@ -1,4 +1,4 @@
-from django.contrib.auth.models import User
+from django.contrib.auth import get_user_model
 from django.test import override_settings
 from django.urls import reverse
 from rest_framework import status
@@ -14,6 +14,9 @@ from wireless.choices import WirelessChannelChoices
 from wireless.models import WirelessLAN
 
 
+User = get_user_model()
+
+
 class AppTest(APITestCase):
 
     def test_root(self):

+ 4 - 1
netbox/dcim/tests/test_filtersets.py

@@ -1,4 +1,4 @@
-from django.contrib.auth.models import User
+from django.contrib.auth import get_user_model
 from django.test import TestCase
 
 from dcim.choices import *
@@ -12,6 +12,9 @@ from virtualization.models import Cluster, ClusterType
 from wireless.choices import WirelessChannelChoices, WirelessRoleChoices
 
 
+User = get_user_model()
+
+
 class DeviceComponentFilterSetTests:
 
     def test_device_type(self):

+ 4 - 1
netbox/dcim/tests/test_views.py

@@ -6,7 +6,7 @@ except ImportError:
     from backports.zoneinfo import ZoneInfo
 
 import yaml
-from django.contrib.auth.models import User
+from django.contrib.auth import get_user_model
 from django.contrib.contenttypes.models import ContentType
 from django.test import override_settings
 from django.urls import reverse
@@ -22,6 +22,9 @@ from utilities.testing import ViewTestCases, create_tags, create_test_device, po
 from wireless.models import WirelessLAN
 
 
+User = get_user_model()
+
+
 class RegionTestCase(ViewTestCases.OrganizationalObjectViewTestCase):
     model = Region
 

+ 2 - 2
netbox/extras/api/serializers.py

@@ -1,4 +1,4 @@
-from django.contrib.auth.models import User
+from django.contrib.auth import get_user_model
 from django.contrib.contenttypes.models import ContentType
 from django.core.exceptions import ObjectDoesNotExist
 from rest_framework import serializers
@@ -256,7 +256,7 @@ class JournalEntrySerializer(NetBoxModelSerializer):
     assigned_object = serializers.SerializerMethodField(read_only=True)
     created_by = serializers.PrimaryKeyRelatedField(
         allow_null=True,
-        queryset=User.objects.all(),
+        queryset=get_user_model().objects.all(),
         required=False,
         default=serializers.CurrentUserDefault()
     )

+ 7 - 7
netbox/extras/filtersets.py

@@ -1,5 +1,5 @@
 import django_filters
-from django.contrib.auth.models import User
+from django.contrib.auth import get_user_model
 from django.contrib.contenttypes.models import ContentType
 from django.db.models import Q
 from django.utils.translation import gettext as _
@@ -159,12 +159,12 @@ class SavedFilterFilterSet(BaseFilterSet):
     )
     content_types = ContentTypeFilter()
     user_id = django_filters.ModelMultipleChoiceFilter(
-        queryset=User.objects.all(),
+        queryset=get_user_model().objects.all(),
         label=_('User (ID)'),
     )
     user = django_filters.ModelMultipleChoiceFilter(
         field_name='user__username',
-        queryset=User.objects.all(),
+        queryset=get_user_model().objects.all(),
         to_field_name='username',
         label=_('User (name)'),
     )
@@ -223,12 +223,12 @@ class JournalEntryFilterSet(NetBoxModelFilterSet):
         queryset=ContentType.objects.all()
     )
     created_by_id = django_filters.ModelMultipleChoiceFilter(
-        queryset=User.objects.all(),
+        queryset=get_user_model().objects.all(),
         label=_('User (ID)'),
     )
     created_by = django_filters.ModelMultipleChoiceFilter(
         field_name='created_by__username',
-        queryset=User.objects.all(),
+        queryset=get_user_model().objects.all(),
         to_field_name='username',
         label=_('User (name)'),
     )
@@ -510,12 +510,12 @@ class ObjectChangeFilterSet(BaseFilterSet):
         queryset=ContentType.objects.all()
     )
     user_id = django_filters.ModelMultipleChoiceFilter(
-        queryset=User.objects.all(),
+        queryset=get_user_model().objects.all(),
         label=_('User (ID)'),
     )
     user = django_filters.ModelMultipleChoiceFilter(
         field_name='user__username',
-        queryset=User.objects.all(),
+        queryset=get_user_model().objects.all(),
         to_field_name='username',
         label=_('User name'),
     )

+ 3 - 3
netbox/extras/forms/filtersets.py

@@ -1,5 +1,5 @@
 from django import forms
-from django.contrib.auth.models import User
+from django.contrib.auth import get_user_model
 from django.contrib.contenttypes.models import ContentType
 from django.utils.translation import gettext as _
 
@@ -385,7 +385,7 @@ class JournalEntryFilterForm(NetBoxModelFilterSetForm):
         widget=DateTimePicker()
     )
     created_by_id = DynamicModelMultipleChoiceField(
-        queryset=User.objects.all(),
+        queryset=get_user_model().objects.all(),
         required=False,
         label=_('User'),
         widget=APISelectMultiple(
@@ -429,7 +429,7 @@ class ObjectChangeFilterForm(SavedFiltersMixin, FilterForm):
         required=False
     )
     user_id = DynamicModelMultipleChoiceField(
-        queryset=User.objects.all(),
+        queryset=get_user_model().objects.all(),
         required=False,
         label=_('User'),
         widget=APISelectMultiple(

+ 3 - 1
netbox/extras/management/commands/runscript.py

@@ -4,7 +4,7 @@ import sys
 import traceback
 import uuid
 
-from django.contrib.auth.models import User
+from django.contrib.auth import get_user_model
 from django.core.management.base import BaseCommand, CommandError
 from django.db import transaction
 
@@ -63,6 +63,8 @@ class Command(BaseCommand):
 
             logger.info(f"Script completed in {job.duration}")
 
+        User = get_user_model()
+
         # Params
         script = options['script']
         loglevel = options['loglevel']

+ 2 - 2
netbox/extras/models/change_logging.py

@@ -1,4 +1,4 @@
-from django.contrib.auth.models import User
+from django.conf import settings
 from django.contrib.contenttypes.fields import GenericForeignKey
 from django.contrib.contenttypes.models import ContentType
 from django.db import models
@@ -24,7 +24,7 @@ class ObjectChange(models.Model):
         db_index=True
     )
     user = models.ForeignKey(
-        to=User,
+        to=settings.AUTH_USER_MODEL,
         on_delete=models.SET_NULL,
         related_name='changes',
         blank=True,

+ 3 - 3
netbox/extras/models/models.py

@@ -3,7 +3,7 @@ import urllib.parse
 
 from django.conf import settings
 from django.contrib import admin
-from django.contrib.auth.models import User
+from django.conf import settings
 from django.contrib.contenttypes.fields import GenericForeignKey
 from django.contrib.contenttypes.models import ContentType
 from django.core.cache import cache
@@ -419,7 +419,7 @@ class SavedFilter(CloningMixin, ExportTemplatesMixin, ChangeLoggedModel):
         blank=True
     )
     user = models.ForeignKey(
-        to=User,
+        to=settings.AUTH_USER_MODEL,
         on_delete=models.SET_NULL,
         blank=True,
         null=True
@@ -560,7 +560,7 @@ class JournalEntry(CustomFieldsMixin, CustomLinksMixin, TagsMixin, ExportTemplat
         fk_field='assigned_object_id'
     )
     created_by = models.ForeignKey(
-        to=User,
+        to=settings.AUTH_USER_MODEL,
         on_delete=models.SET_NULL,
         blank=True,
         null=True

+ 4 - 1
netbox/extras/tests/test_api.py

@@ -1,6 +1,6 @@
 import datetime
 
-from django.contrib.auth.models import User
+from django.contrib.auth import get_user_model
 from django.contrib.contenttypes.models import ContentType
 from django.urls import reverse
 from django.utils.timezone import make_aware
@@ -15,6 +15,9 @@ from extras.scripts import BooleanVar, IntegerVar, Script, StringVar
 from utilities.testing import APITestCase, APIViewTestCases
 
 
+User = get_user_model()
+
+
 class AppTest(APITestCase):
 
     def test_root(self):

+ 4 - 1
netbox/extras/tests/test_filtersets.py

@@ -1,7 +1,7 @@
 import uuid
 from datetime import datetime, timezone
 
-from django.contrib.auth.models import User
+from django.contrib.auth import get_user_model
 from django.contrib.contenttypes.models import ContentType
 from django.test import TestCase
 
@@ -18,6 +18,9 @@ from utilities.testing import BaseFilterSetTests, ChangeLoggedFilterSetTests, cr
 from virtualization.models import Cluster, ClusterGroup, ClusterType
 
 
+User = get_user_model()
+
+
 class CustomFieldTestCase(TestCase, BaseFilterSetTests):
     queryset = CustomField.objects.all()
     filterset = CustomFieldFilterSet

+ 4 - 1
netbox/extras/tests/test_views.py

@@ -1,7 +1,7 @@
 import urllib.parse
 import uuid
 
-from django.contrib.auth.models import User
+from django.contrib.auth import get_user_model
 from django.contrib.contenttypes.models import ContentType
 from django.urls import reverse
 
@@ -11,6 +11,9 @@ from extras.models import *
 from utilities.testing import ViewTestCases, TestCase
 
 
+User = get_user_model()
+
+
 class CustomFieldTestCase(ViewTestCases.PrimaryObjectViewTestCase):
     model = CustomField
 

+ 5 - 1
netbox/netbox/tests/test_authentication.py

@@ -1,7 +1,8 @@
 import datetime
 
 from django.conf import settings
-from django.contrib.auth.models import Group, User
+from django.contrib.auth import get_user_model
+from django.contrib.auth.models import Group
 from django.contrib.contenttypes.models import ContentType
 from django.test import Client
 from django.test.utils import override_settings
@@ -16,6 +17,9 @@ from utilities.testing import TestCase
 from utilities.testing.api import APITestCase
 
 
+User = get_user_model()
+
+
 class TokenAuthenticationTestCase(APITestCase):
 
     @override_settings(LOGIN_REQUIRED=True, EXEMPT_VIEW_PERMISSIONS=['*'])

+ 3 - 2
netbox/users/api/nested_serializers.py

@@ -1,4 +1,5 @@
-from django.contrib.auth.models import Group, User
+from django.contrib.auth import get_user_model
+from django.contrib.auth.models import Group
 from django.contrib.contenttypes.models import ContentType
 from drf_spectacular.utils import extend_schema_field
 from drf_spectacular.types import OpenApiTypes
@@ -28,7 +29,7 @@ class NestedUserSerializer(WritableNestedSerializer):
     url = serializers.HyperlinkedIdentityField(view_name='users-api:user-detail')
 
     class Meta:
-        model = User
+        model = get_user_model()
         fields = ['id', 'url', 'display', 'username']
 
     @extend_schema_field(OpenApiTypes.STR)

+ 4 - 3
netbox/users/api/serializers.py

@@ -1,5 +1,6 @@
 from django.conf import settings
-from django.contrib.auth.models import Group, User
+from django.contrib.auth import get_user_model
+from django.contrib.auth.models import Group
 from django.contrib.contenttypes.models import ContentType
 from drf_spectacular.utils import extend_schema_field
 from drf_spectacular.types import OpenApiTypes
@@ -30,7 +31,7 @@ class UserSerializer(ValidatedModelSerializer):
     )
 
     class Meta:
-        model = User
+        model = get_user_model()
         fields = (
             'id', 'url', 'display', 'username', 'password', 'first_name', 'last_name', 'email', 'is_staff', 'is_active',
             'date_joined', 'groups',
@@ -124,7 +125,7 @@ class ObjectPermissionSerializer(ValidatedModelSerializer):
         many=True
     )
     users = SerializedPKRelatedField(
-        queryset=User.objects.all(),
+        queryset=get_user_model().objects.all(),
         serializer=NestedUserSerializer,
         required=False,
         many=True

+ 3 - 2
netbox/users/api/views.py

@@ -1,5 +1,6 @@
 from django.contrib.auth import authenticate
-from django.contrib.auth.models import Group, User
+from django.contrib.auth import get_user_model
+from django.contrib.auth.models import Group
 from django.db.models import Count
 from drf_spectacular.utils import extend_schema
 from drf_spectacular.types import OpenApiTypes
@@ -32,7 +33,7 @@ class UsersRootView(APIRootView):
 #
 
 class UserViewSet(NetBoxModelViewSet):
-    queryset = RestrictedQuerySet(model=User).prefetch_related('groups').order_by('username')
+    queryset = RestrictedQuerySet(model=get_user_model()).prefetch_related('groups').order_by('username')
     serializer_class = serializers.UserSerializer
     filterset_class = filtersets.UserFilterSet
 

+ 7 - 6
netbox/users/filtersets.py

@@ -1,5 +1,6 @@
 import django_filters
-from django.contrib.auth.models import Group, User
+from django.contrib.auth import get_user_model
+from django.contrib.auth.models import Group
 from django.db.models import Q
 from django.utils.translation import gettext as _
 
@@ -47,7 +48,7 @@ class UserFilterSet(BaseFilterSet):
     )
 
     class Meta:
-        model = User
+        model = get_user_model()
         fields = ['id', 'username', 'first_name', 'last_name', 'email', 'is_staff', 'is_active']
 
     def search(self, queryset, name, value):
@@ -68,12 +69,12 @@ class TokenFilterSet(BaseFilterSet):
     )
     user_id = django_filters.ModelMultipleChoiceFilter(
         field_name='user',
-        queryset=User.objects.all(),
+        queryset=get_user_model().objects.all(),
         label=_('User'),
     )
     user = django_filters.ModelMultipleChoiceFilter(
         field_name='user__username',
-        queryset=User.objects.all(),
+        queryset=get_user_model().objects.all(),
         to_field_name='username',
         label=_('User (name)'),
     )
@@ -116,12 +117,12 @@ class ObjectPermissionFilterSet(BaseFilterSet):
     )
     user_id = django_filters.ModelMultipleChoiceFilter(
         field_name='users',
-        queryset=User.objects.all(),
+        queryset=get_user_model().objects.all(),
         label=_('User'),
     )
     user = django_filters.ModelMultipleChoiceFilter(
         field_name='users__username',
-        queryset=User.objects.all(),
+        queryset=get_user_model().objects.all(),
         to_field_name='username',
         label=_('User (name)'),
     )

+ 3 - 2
netbox/users/graphql/schema.py

@@ -1,6 +1,7 @@
 import graphene
 
-from django.contrib.auth.models import Group, User
+from django.contrib.auth import get_user_model
+from django.contrib.auth.models import Group
 from netbox.graphql.fields import ObjectField, ObjectListField
 from .types import *
 from utilities.graphql_optimizer import gql_query_optimizer
@@ -17,4 +18,4 @@ class UsersQuery(graphene.ObjectType):
     user_list = ObjectListField(UserType)
 
     def resolve_user_list(root, info, **kwargs):
-        return gql_query_optimizer(User.objects.all(), info)
+        return gql_query_optimizer(get_user_model().objects.all(), info)

+ 4 - 3
netbox/users/graphql/types.py

@@ -1,4 +1,5 @@
-from django.contrib.auth.models import Group, User
+from django.contrib.auth import get_user_model
+from django.contrib.auth.models import Group
 from graphene_django import DjangoObjectType
 
 from users import filtersets
@@ -25,7 +26,7 @@ class GroupType(DjangoObjectType):
 class UserType(DjangoObjectType):
 
     class Meta:
-        model = User
+        model = get_user_model()
         fields = (
             'id', 'username', 'password', 'first_name', 'last_name', 'email', 'is_staff', 'is_active', 'date_joined',
             'groups',
@@ -34,4 +35,4 @@ class UserType(DjangoObjectType):
 
     @classmethod
     def get_queryset(cls, queryset, info):
-        return RestrictedQuerySet(model=User).restrict(info.context.user, 'view')
+        return RestrictedQuerySet(model=get_user_model()).restrict(info.context.user, 'view')

+ 5 - 1
netbox/users/tests/test_api.py

@@ -1,4 +1,5 @@
-from django.contrib.auth.models import Group, User
+from django.contrib.auth import get_user_model
+from django.contrib.auth.models import Group
 from django.contrib.contenttypes.models import ContentType
 from django.urls import reverse
 
@@ -7,6 +8,9 @@ from utilities.testing import APIViewTestCases, APITestCase
 from utilities.utils import deepmerge
 
 
+User = get_user_model()
+
+
 class AppTest(APITestCase):
 
     def test_root(self):

+ 5 - 1
netbox/users/tests/test_filtersets.py

@@ -1,6 +1,7 @@
 import datetime
 
-from django.contrib.auth.models import Group, User
+from django.contrib.auth import get_user_model
+from django.contrib.auth.models import Group
 from django.contrib.contenttypes.models import ContentType
 from django.test import TestCase
 from django.utils.timezone import make_aware
@@ -10,6 +11,9 @@ from users.models import ObjectPermission, Token
 from utilities.testing import BaseFilterSetTests
 
 
+User = get_user_model()
+
+
 class UserTestCase(TestCase, BaseFilterSetTests):
     queryset = User.objects.all()
     filterset = filtersets.UserFilterSet

+ 4 - 1
netbox/users/tests/test_models.py

@@ -1,7 +1,10 @@
-from django.contrib.auth.models import User
+from django.contrib.auth import get_user_model
 from django.test import TestCase
 
 
+User = get_user_model()
+
+
 class UserConfigTest(TestCase):
 
     @classmethod

+ 4 - 1
netbox/users/tests/test_preferences.py

@@ -1,4 +1,4 @@
-from django.contrib.auth.models import User
+from django.contrib.auth import get_user_model
 from django.test import override_settings
 from django.test.client import RequestFactory
 from django.urls import reverse
@@ -16,6 +16,9 @@ DEFAULT_USER_PREFERENCES = {
 }
 
 
+User = get_user_model()
+
+
 class UserPreferencesTest(TestCase):
     user_permissions = ['dcim.view_site']
 

+ 4 - 1
netbox/utilities/testing/api.py

@@ -2,7 +2,7 @@ import inspect
 import json
 
 from django.conf import settings
-from django.contrib.auth.models import User
+from django.contrib.auth import get_user_model
 from django.contrib.contenttypes.models import ContentType
 from django.urls import reverse
 from django.test import override_settings
@@ -26,6 +26,9 @@ __all__ = (
 )
 
 
+User = get_user_model()
+
+
 #
 # REST/GraphQL API Tests
 #

+ 2 - 2
netbox/utilities/testing/base.py

@@ -1,6 +1,6 @@
 import json
 
-from django.contrib.auth.models import User
+from django.contrib.auth import get_user_model
 from django.contrib.contenttypes.models import ContentType
 from django.contrib.postgres.fields import ArrayField
 from django.core.exceptions import FieldDoesNotExist
@@ -27,7 +27,7 @@ class TestCase(_TestCase):
     def setUp(self):
 
         # Create the test user and assign permissions
-        self.user = User.objects.create_user(username='testuser')
+        self.user = get_user_model().objects.create_user(username='testuser')
         self.add_permissions(*self.user_permissions)
 
         # Initialize the test client

+ 3 - 2
netbox/utilities/testing/utils.py

@@ -2,7 +2,8 @@ import logging
 import re
 from contextlib import contextmanager
 
-from django.contrib.auth.models import Permission, User
+from django.contrib.auth import get_user_model
+from django.contrib.auth.models import Permission
 from django.utils.text import slugify
 
 from dcim.models import Device, DeviceRole, DeviceType, Manufacturer, Site
@@ -63,7 +64,7 @@ def create_test_user(username='testuser', permissions=None):
     """
     Create a User with the given permissions.
     """
-    user = User.objects.create_user(username=username)
+    user = get_user_model().objects.create_user(username=username)
     if permissions is None:
         permissions = ()
     for perm_name in permissions: