Sfoglia il codice sorgente

Merge pull request #4879 from netbox-community/4877-users-api-endpoint

4877 users api endpoint
Jeremy Stretch 5 anni fa
parent
commit
1fcefc486c

+ 8 - 0
netbox/netbox/settings.py

@@ -382,6 +382,14 @@ LOGIN_URL = '/{}login/'.format(BASE_PATH)
 
 
 CSRF_TRUSTED_ORIGINS = ALLOWED_HOSTS
 CSRF_TRUSTED_ORIGINS = ALLOWED_HOSTS
 
 
+# Exclude potentially sensitive models from wildcard view exemption. These may still be exempted
+# by specifying the model individually in the EXEMPT_VIEW_PERMISSIONS configuration parameter.
+EXEMPT_EXCLUDE_MODELS = (
+    ('auth', 'group'),
+    ('auth', 'user'),
+    ('users', 'objectpermission'),
+)
+
 #
 #
 # Caching
 # Caching
 #
 #

+ 6 - 3
netbox/users/api/nested_serializers.py

@@ -13,20 +13,23 @@ __all__ = [
 
 
 
 
 class NestedGroupSerializer(WritableNestedSerializer):
 class NestedGroupSerializer(WritableNestedSerializer):
+    url = serializers.HyperlinkedIdentityField(view_name='users-api:group-detail')
 
 
     class Meta:
     class Meta:
         model = Group
         model = Group
-        fields = ['id', 'name']
+        fields = ['id', 'url', 'name']
 
 
 
 
 class NestedUserSerializer(WritableNestedSerializer):
 class NestedUserSerializer(WritableNestedSerializer):
+    url = serializers.HyperlinkedIdentityField(view_name='users-api:user-detail')
 
 
     class Meta:
     class Meta:
         model = User
         model = User
-        fields = ['id', 'username']
+        fields = ['id', 'url', 'username']
 
 
 
 
 class NestedObjectPermissionSerializer(WritableNestedSerializer):
 class NestedObjectPermissionSerializer(WritableNestedSerializer):
+    url = serializers.HyperlinkedIdentityField(view_name='users-api:objectpermission-detail')
     object_types = ContentTypeField(
     object_types = ContentTypeField(
         queryset=ContentType.objects.all(),
         queryset=ContentType.objects.all(),
         many=True
         many=True
@@ -36,7 +39,7 @@ class NestedObjectPermissionSerializer(WritableNestedSerializer):
 
 
     class Meta:
     class Meta:
         model = ObjectPermission
         model = ObjectPermission
-        fields = ['id', 'name', 'enabled', 'object_types', 'groups', 'users', 'actions']
+        fields = ['id', 'url', 'name', 'enabled', 'object_types', 'groups', 'users', 'actions']
 
 
     def get_groups(self, obj):
     def get_groups(self, obj):
         return [g.name for g in obj.groups.all()]
         return [g.name for g in obj.groups.all()]

+ 26 - 0
netbox/users/api/serializers.py

@@ -7,6 +7,32 @@ from utilities.api import ContentTypeField, SerializedPKRelatedField, ValidatedM
 from .nested_serializers import *
 from .nested_serializers import *
 
 
 
 
+class UserSerializer(ValidatedModelSerializer):
+    url = serializers.HyperlinkedIdentityField(view_name='users-api:user-detail')
+    groups = SerializedPKRelatedField(
+        queryset=Group.objects.all(),
+        serializer=NestedGroupSerializer,
+        required=False,
+        many=True
+    )
+
+    class Meta:
+        model = User
+        fields = (
+            'id', 'url', 'username', 'first_name', 'last_name', 'email', 'is_staff', 'is_active', 'date_joined',
+            'groups',
+        )
+
+
+class GroupSerializer(ValidatedModelSerializer):
+    url = serializers.HyperlinkedIdentityField(view_name='users-api:group-detail')
+    user_count = serializers.IntegerField(read_only=True)
+
+    class Meta:
+        model = Group
+        fields = ('id', 'url', 'name', 'user_count')
+
+
 class ObjectPermissionSerializer(ValidatedModelSerializer):
 class ObjectPermissionSerializer(ValidatedModelSerializer):
     url = serializers.HyperlinkedIdentityField(view_name='users-api:objectpermission-detail')
     url = serializers.HyperlinkedIdentityField(view_name='users-api:objectpermission-detail')
     object_types = ContentTypeField(
     object_types = ContentTypeField(

+ 4 - 0
netbox/users/api/urls.py

@@ -14,6 +14,10 @@ class UsersRootView(routers.APIRootView):
 router = routers.DefaultRouter()
 router = routers.DefaultRouter()
 router.APIRootView = UsersRootView
 router.APIRootView = UsersRootView
 
 
+# Users and groups
+router.register('users', views.UserViewSet)
+router.register('groups', views.GroupViewSet)
+
 # Permissions
 # Permissions
 router.register('permissions', views.ObjectPermissionViewSet)
 router.register('permissions', views.ObjectPermissionViewSet)
 
 

+ 21 - 1
netbox/users/api/views.py

@@ -1,7 +1,27 @@
+from django.contrib.auth.models import Group, User
+from django.db.models import Count
+
+from users import filters
+from users.models import ObjectPermission
 from utilities.api import ModelViewSet
 from utilities.api import ModelViewSet
+from utilities.querysets import RestrictedQuerySet
 from . import serializers
 from . import serializers
 
 
-from users.models import ObjectPermission
+
+#
+# Users and groups
+#
+
+class UserViewSet(ModelViewSet):
+    queryset = RestrictedQuerySet(model=User).prefetch_related('groups')
+    serializer_class = serializers.UserSerializer
+    filterset_class = filters.UserFilterSet
+
+
+class GroupViewSet(ModelViewSet):
+    queryset = RestrictedQuerySet(model=Group).annotate(user_count=Count('user'))
+    serializer_class = serializers.GroupSerializer
+    filterset_class = filters.GroupFilterSet
 
 
 
 
 #
 #

+ 58 - 0
netbox/users/filters.py

@@ -0,0 +1,58 @@
+import django_filters
+from django.contrib.auth.models import Group, User
+from django.db.models import Q
+
+from utilities.filters import BaseFilterSet
+
+__all__ = (
+    'GroupFilterSet',
+    'UserFilterSet',
+)
+
+
+class GroupFilterSet(BaseFilterSet):
+    q = django_filters.CharFilter(
+        method='search',
+        label='Search',
+    )
+
+    class Meta:
+        model = Group
+        fields = ['id', 'name']
+
+    def search(self, queryset, name, value):
+        if not value.strip():
+            return queryset
+        return queryset.filter(name__icontains=value)
+
+
+class UserFilterSet(BaseFilterSet):
+    q = django_filters.CharFilter(
+        method='search',
+        label='Search',
+    )
+    group_id = django_filters.ModelMultipleChoiceFilter(
+        field_name='groups',
+        queryset=Group.objects.all(),
+        label='Group',
+    )
+    group = django_filters.ModelMultipleChoiceFilter(
+        field_name='groups__name',
+        queryset=Group.objects.all(),
+        to_field_name='name',
+        label='Group (name)',
+    )
+
+    class Meta:
+        model = User
+        fields = ['id', 'username', 'first_name', 'last_name', 'email', 'is_staff', 'is_active']
+
+    def search(self, queryset, name, value):
+        if not value.strip():
+            return queryset
+        return queryset.filter(
+            Q(username__icontains=value) |
+            Q(first_name__icontains=value) |
+            Q(last_name__icontains=value) |
+            Q(email__icontains=value)
+        )

+ 55 - 15
netbox/users/tests/test_api.py

@@ -18,9 +18,63 @@ class AppTest(APITestCase):
         self.assertEqual(response.status_code, 200)
         self.assertEqual(response.status_code, 200)
 
 
 
 
+class UserTest(APIViewTestCases.APIViewTestCase):
+    model = User
+    view_namespace = 'users'
+    brief_fields = ['id', 'url', 'username']
+    create_data = [
+        {
+            'username': 'User_4',
+        },
+        {
+            'username': 'User_5',
+        },
+        {
+            'username': 'User_6',
+        },
+    ]
+
+    @classmethod
+    def setUpTestData(cls):
+
+        users = (
+            User(username='User_1'),
+            User(username='User_2'),
+            User(username='User_3'),
+        )
+        User.objects.bulk_create(users)
+
+
+class GroupTest(APIViewTestCases.APIViewTestCase):
+    model = Group
+    view_namespace = 'users'
+    brief_fields = ['id', 'name', 'url']
+    create_data = [
+        {
+            'name': 'Group 4',
+        },
+        {
+            'name': 'Group 5',
+        },
+        {
+            'name': 'Group 6',
+        },
+    ]
+
+    @classmethod
+    def setUpTestData(cls):
+
+        users = (
+            Group(name='Group 1'),
+            Group(name='Group 2'),
+            Group(name='Group 3'),
+        )
+        Group.objects.bulk_create(users)
+
+
 class ObjectPermissionTest(APIViewTestCases.APIViewTestCase):
 class ObjectPermissionTest(APIViewTestCases.APIViewTestCase):
     model = ObjectPermission
     model = ObjectPermission
-    brief_fields = ['actions', 'enabled', 'groups', 'id', 'name', 'object_types', 'users']
+    brief_fields = ['actions', 'enabled', 'groups', 'id', 'name', 'object_types', 'url', 'users']
 
 
     @classmethod
     @classmethod
     def setUpTestData(cls):
     def setUpTestData(cls):
@@ -74,17 +128,3 @@ class ObjectPermissionTest(APIViewTestCases.APIViewTestCase):
                 'constraints': {'name': 'TEST6'},
                 'constraints': {'name': 'TEST6'},
             },
             },
         ]
         ]
-
-    @override_settings(EXEMPT_VIEW_PERMISSIONS=['*'])
-    def test_list_objects_anonymous(self):
-        # Endpoint should never be exposed via EXEMPT_VIEW_PERMISSIONS
-        url = self._get_list_url()
-        with disable_warnings('django.request'):
-            self.assertHttpStatus(self.client.get(url, **self.header), status.HTTP_403_FORBIDDEN)
-
-    @override_settings(EXEMPT_VIEW_PERMISSIONS=['*'])
-    def test_get_object_anonymous(self):
-        # Endpoint should never be exposed via EXEMPT_VIEW_PERMISSIONS
-        url = self._get_detail_url(self._get_queryset().first())
-        with disable_warnings('django.request'):
-            self.assertHttpStatus(self.client.get(url, **self.header), status.HTTP_403_FORBIDDEN)

+ 116 - 0
netbox/users/tests/test_filters.py

@@ -0,0 +1,116 @@
+from django.contrib.auth.models import Group, User
+from django.test import TestCase
+
+from users.filters import GroupFilterSet, UserFilterSet
+
+
+class UserTestCase(TestCase):
+    queryset = User.objects.all()
+    filterset = UserFilterSet
+
+    @classmethod
+    def setUpTestData(cls):
+
+        groups = (
+            Group(name='Group 1'),
+            Group(name='Group 2'),
+            Group(name='Group 3'),
+        )
+        Group.objects.bulk_create(groups)
+
+        users = (
+            User(
+                username='User1',
+                first_name='Hank',
+                last_name='Hill',
+                email='hank@stricklandpropane.com',
+                is_staff=True
+            ),
+            User(
+                username='User2',
+                first_name='Dale',
+                last_name='Gribble',
+                email='dale@dalesdeadbug.com'
+            ),
+            User(
+                username='User3',
+                first_name='Bill',
+                last_name='Dauterive',
+                email='bill.dauterive@army.mil'
+            ),
+            User(
+                username='User4',
+                first_name='Jeff',
+                last_name='Boomhauer',
+                email='boomhauer@dangolemail.com'
+            ),
+            User(
+                username='User5',
+                first_name='Debbie',
+                last_name='Grund',
+                is_active=False
+            )
+        )
+        User.objects.bulk_create(users)
+
+        users[0].groups.set([groups[0]])
+        users[1].groups.set([groups[1]])
+        users[2].groups.set([groups[2]])
+
+    def test_id(self):
+        params = {'id': self.queryset.values_list('pk', flat=True)[:2]}
+        self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
+
+    def test_username(self):
+        params = {'username': ['User1', 'User2']}
+        self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
+
+    def test_first_name(self):
+        params = {'first_name': ['Hank', 'Dale']}
+        self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
+
+    def test_last_name(self):
+        params = {'last_name': ['Hill', 'Gribble']}
+        self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
+
+    def test_email(self):
+        params = {'email': ['hank@stricklandpropane.com', 'dale@dalesdeadbug.com']}
+        self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
+
+    def test_is_staff(self):
+        params = {'is_staff': True}
+        self.assertEqual(self.filterset(params, self.queryset).qs.count(), 1)
+
+    def test_is_active(self):
+        params = {'is_active': True}
+        self.assertEqual(self.filterset(params, self.queryset).qs.count(), 4)
+
+    def test_group(self):
+        groups = Group.objects.all()[:2]
+        params = {'group_id': [groups[0].pk, groups[1].pk]}
+        self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
+        params = {'group': [groups[0].name, groups[1].name]}
+        self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
+
+
+class GroupTestCase(TestCase):
+    queryset = Group.objects.all()
+    filterset = GroupFilterSet
+
+    @classmethod
+    def setUpTestData(cls):
+
+        groups = (
+            Group(name='Group 1'),
+            Group(name='Group 2'),
+            Group(name='Group 3'),
+        )
+        Group.objects.bulk_create(groups)
+
+    def test_id(self):
+        params = {'id': self.queryset.values_list('pk', flat=True)[:2]}
+        self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
+
+    def test_name(self):
+        params = {'name': ['Group 1', 'Group 2']}
+        self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)

+ 4 - 3
netbox/utilities/api.py

@@ -32,9 +32,10 @@ def get_serializer_for_model(model, prefix=''):
     Dynamically resolve and return the appropriate serializer for a model.
     Dynamically resolve and return the appropriate serializer for a model.
     """
     """
     app_name, model_name = model._meta.label.split('.')
     app_name, model_name = model._meta.label.split('.')
-    serializer_name = '{}.api.serializers.{}{}Serializer'.format(
-        app_name, prefix, model_name
-    )
+    # Serializers for Django's auth models are in the users app
+    if app_name == 'auth':
+        app_name = 'users'
+    serializer_name = f'{app_name}.api.serializers.{prefix}{model_name}Serializer'
     try:
     try:
         return dynamic_import(serializer_name)
         return dynamic_import(serializer_name)
     except AttributeError:
     except AttributeError:

+ 1 - 7
netbox/utilities/permissions.py

@@ -1,12 +1,6 @@
 from django.conf import settings
 from django.conf import settings
 from django.contrib.contenttypes.models import ContentType
 from django.contrib.contenttypes.models import ContentType
 
 
-# Exclude potentially sensitive models from wild view exemption. These may still be exempted
-# by specifying the model individually in the EXEMPT_VIEW_PERMISSIONS configuration parameter.
-EXEMPT_EXCLUDE_MODELS = (
-    ('users', 'objectpermission'),
-)
-
 
 
 def get_permission_for_model(model, action):
 def get_permission_for_model(model, action):
     """
     """
@@ -70,7 +64,7 @@ def permission_is_exempt(name):
     if action == 'view':
     if action == 'view':
         if (
         if (
             # All models (excluding those in EXEMPT_EXCLUDE_MODELS) are exempt from view permission enforcement
             # All models (excluding those in EXEMPT_EXCLUDE_MODELS) are exempt from view permission enforcement
-            '*' in settings.EXEMPT_VIEW_PERMISSIONS and (app_label, model_name) not in EXEMPT_EXCLUDE_MODELS
+            '*' in settings.EXEMPT_VIEW_PERMISSIONS and (app_label, model_name) not in settings.EXEMPT_EXCLUDE_MODELS
         ) or (
         ) or (
             # This specific model is exempt from view permission enforcement
             # This specific model is exempt from view permission enforcement
             f'{app_label}.{model_name}' in settings.EXEMPT_VIEW_PERMISSIONS
             f'{app_label}.{model_name}' in settings.EXEMPT_VIEW_PERMISSIONS

+ 28 - 8
netbox/utilities/testing/api.py

@@ -1,3 +1,4 @@
+from django.conf import settings
 from django.contrib.auth.models import User
 from django.contrib.auth.models import User
 from django.contrib.contenttypes.models import ContentType
 from django.contrib.contenttypes.models import ContentType
 from django.urls import reverse
 from django.urls import reverse
@@ -21,7 +22,14 @@ __all__ = (
 #
 #
 
 
 class APITestCase(ModelTestCase):
 class APITestCase(ModelTestCase):
+    """
+    Base test case for API requests.
+
+    client_class: Test client class
+    view_namespace: Namespace for API views. If None, the model's app_label will be used.
+    """
     client_class = APIClient
     client_class = APIClient
+    view_namespace = None
 
 
     def setUp(self):
     def setUp(self):
         """
         """
@@ -33,12 +41,15 @@ class APITestCase(ModelTestCase):
         self.token = Token.objects.create(user=self.user)
         self.token = Token.objects.create(user=self.user)
         self.header = {'HTTP_AUTHORIZATION': 'Token {}'.format(self.token.key)}
         self.header = {'HTTP_AUTHORIZATION': 'Token {}'.format(self.token.key)}
 
 
+    def _get_view_namespace(self):
+        return f'{self.view_namespace or self.model._meta.app_label}-api'
+
     def _get_detail_url(self, instance):
     def _get_detail_url(self, instance):
-        viewname = f'{instance._meta.app_label}-api:{instance._meta.model_name}-detail'
+        viewname = f'{self._get_view_namespace()}:{instance._meta.model_name}-detail'
         return reverse(viewname, kwargs={'pk': instance.pk})
         return reverse(viewname, kwargs={'pk': instance.pk})
 
 
     def _get_list_url(self):
     def _get_list_url(self):
-        viewname = f'{self.model._meta.app_label}-api:{self.model._meta.model_name}-list'
+        viewname = f'{self._get_view_namespace()}:{self.model._meta.model_name}-list'
         return reverse(viewname)
         return reverse(viewname)
 
 
 
 
@@ -52,8 +63,13 @@ class APIViewTestCases:
             GET a single object as an unauthenticated user.
             GET a single object as an unauthenticated user.
             """
             """
             url = self._get_detail_url(self._get_queryset().first())
             url = self._get_detail_url(self._get_queryset().first())
-            response = self.client.get(url, **self.header)
-            self.assertHttpStatus(response, status.HTTP_200_OK)
+            if (self.model._meta.app_label, self.model._meta.model_name) in settings.EXEMPT_EXCLUDE_MODELS:
+                # Models listed in EXEMPT_EXCLUDE_MODELS should not be accessible to anonymous users
+                with disable_warnings('django.request'):
+                    self.assertHttpStatus(self.client.get(url, **self.header), status.HTTP_403_FORBIDDEN)
+            else:
+                response = self.client.get(url, **self.header)
+                self.assertHttpStatus(response, status.HTTP_200_OK)
 
 
         @override_settings(EXEMPT_VIEW_PERMISSIONS=[])
         @override_settings(EXEMPT_VIEW_PERMISSIONS=[])
         def test_get_object_without_permission(self):
         def test_get_object_without_permission(self):
@@ -101,10 +117,14 @@ class APIViewTestCases:
             GET a list of objects as an unauthenticated user.
             GET a list of objects as an unauthenticated user.
             """
             """
             url = self._get_list_url()
             url = self._get_list_url()
-            response = self.client.get(url, **self.header)
-
-            self.assertHttpStatus(response, status.HTTP_200_OK)
-            self.assertEqual(len(response.data['results']), self._get_queryset().count())
+            if (self.model._meta.app_label, self.model._meta.model_name) in settings.EXEMPT_EXCLUDE_MODELS:
+                # Models listed in EXEMPT_EXCLUDE_MODELS should not be accessible to anonymous users
+                with disable_warnings('django.request'):
+                    self.assertHttpStatus(self.client.get(url, **self.header), status.HTTP_403_FORBIDDEN)
+            else:
+                response = self.client.get(url, **self.header)
+                self.assertHttpStatus(response, status.HTTP_200_OK)
+                self.assertEqual(len(response.data['results']), self._get_queryset().count())
 
 
         @override_settings(EXEMPT_VIEW_PERMISSIONS=[])
         @override_settings(EXEMPT_VIEW_PERMISSIONS=[])
         def test_list_objects_brief(self):
         def test_list_objects_brief(self):