Procházet zdrojové kódy

Closes #5121: Add content_type filters for tags

jeremystretch před 4 roky
rodič
revize
5b4793a2d5

+ 33 - 1
netbox/extras/filtersets.py

@@ -6,7 +6,7 @@ from django.db.models import Q
 from dcim.models import DeviceRole, DeviceType, Platform, Region, Site, SiteGroup
 from dcim.models import DeviceRole, DeviceType, Platform, Region, Site, SiteGroup
 from netbox.filtersets import BaseFilterSet, ChangeLoggedModelFilterSet
 from netbox.filtersets import BaseFilterSet, ChangeLoggedModelFilterSet
 from tenancy.models import Tenant, TenantGroup
 from tenancy.models import Tenant, TenantGroup
-from utilities.filters import ContentTypeFilter
+from utilities.filters import ContentTypeFilter, MultiValueCharFilter, MultiValueNumberFilter
 from virtualization.models import Cluster, ClusterGroup
 from virtualization.models import Cluster, ClusterGroup
 from .choices import *
 from .choices import *
 from .models import *
 from .models import *
@@ -114,6 +114,12 @@ class TagFilterSet(ChangeLoggedModelFilterSet):
         method='search',
         method='search',
         label='Search',
         label='Search',
     )
     )
+    content_type = MultiValueCharFilter(
+        method='_content_type'
+    )
+    content_type_id = MultiValueNumberFilter(
+        method='_content_type_id'
+    )
 
 
     class Meta:
     class Meta:
         model = Tag
         model = Tag
@@ -127,6 +133,32 @@ class TagFilterSet(ChangeLoggedModelFilterSet):
             Q(slug__icontains=value)
             Q(slug__icontains=value)
         )
         )
 
 
+    def _content_type(self, queryset, name, values):
+        ct_filter = Q()
+
+        # Compile list of app_label & model pairings
+        for value in values:
+            try:
+                app_label, model = value.lower().split('.')
+                ct_filter |= Q(
+                    app_label=app_label,
+                    model=model
+                )
+            except ValueError:
+                pass
+
+        # Get ContentType instances
+        content_types = ContentType.objects.filter(ct_filter)
+
+        return queryset.filter(extras_taggeditem_items__content_type__in=content_types).distinct()
+
+    def _content_type_id(self, queryset, name, values):
+
+        # Get ContentType instances
+        content_types = ContentType.objects.filter(pk__in=values)
+
+        return queryset.filter(extras_taggeditem_items__content_type__in=content_types).distinct()
+
 
 
 class ConfigContextFilterSet(ChangeLoggedModelFilterSet):
 class ConfigContextFilterSet(ChangeLoggedModelFilterSet):
     q = django_filters.CharFilter(
     q = django_filters.CharFilter(

+ 8 - 2
netbox/extras/forms.py

@@ -8,12 +8,13 @@ from dcim.models import DeviceRole, DeviceType, Platform, Region, Site, SiteGrou
 from tenancy.models import Tenant, TenantGroup
 from tenancy.models import Tenant, TenantGroup
 from utilities.forms import (
 from utilities.forms import (
     add_blank_choice, APISelectMultiple, BootstrapMixin, BulkEditForm, BulkEditNullBooleanSelect, ColorSelect,
     add_blank_choice, APISelectMultiple, BootstrapMixin, BulkEditForm, BulkEditNullBooleanSelect, ColorSelect,
-    CommentField, CSVModelForm, DateTimePicker, DynamicModelMultipleChoiceField, JSONField, SlugField, StaticSelect2,
-    BOOLEAN_WITH_BLANK_CHOICES,
+    CommentField, ContentTypeMultipleChoiceField, CSVModelForm, DateTimePicker, DynamicModelMultipleChoiceField,
+    JSONField, SlugField, StaticSelect2, BOOLEAN_WITH_BLANK_CHOICES,
 )
 )
 from virtualization.models import Cluster, ClusterGroup
 from virtualization.models import Cluster, ClusterGroup
 from .choices import *
 from .choices import *
 from .models import ConfigContext, CustomField, ImageAttachment, JournalEntry, ObjectChange, Tag
 from .models import ConfigContext, CustomField, ImageAttachment, JournalEntry, ObjectChange, Tag
+from .utils import FeatureQuery
 
 
 
 
 #
 #
@@ -180,6 +181,11 @@ class TagFilterForm(BootstrapMixin, forms.Form):
         required=False,
         required=False,
         label=_('Search')
         label=_('Search')
     )
     )
+    content_type_id = ContentTypeMultipleChoiceField(
+        queryset=ContentType.objects.filter(FeatureQuery('tags').get_query()),
+        required=False,
+        label=_('Tagged object type')
+    )
 
 
 
 
 class TagBulkEditForm(BootstrapMixin, BulkEditForm):
 class TagBulkEditForm(BootstrapMixin, BulkEditForm):

+ 16 - 0
netbox/extras/tests/test_filtersets.py

@@ -5,6 +5,7 @@ from django.contrib.auth.models import User
 from django.contrib.contenttypes.models import ContentType
 from django.contrib.contenttypes.models import ContentType
 from django.test import TestCase
 from django.test import TestCase
 
 
+from circuits.models import Provider
 from dcim.models import DeviceRole, DeviceType, Manufacturer, Platform, Rack, Region, Site, SiteGroup
 from dcim.models import DeviceRole, DeviceType, Manufacturer, Platform, Rack, Region, Site, SiteGroup
 from extras.choices import JournalEntryKindChoices, ObjectChangeActionChoices
 from extras.choices import JournalEntryKindChoices, ObjectChangeActionChoices
 from extras.filtersets import *
 from extras.filtersets import *
@@ -537,6 +538,13 @@ class TagTestCase(TestCase, ChangeLoggedFilterSetTests):
         )
         )
         Tag.objects.bulk_create(tags)
         Tag.objects.bulk_create(tags)
 
 
+        # Apply some tags so we can filter by content type
+        site = Site.objects.create(name='Site 1', slug='site-1')
+        provider = Provider.objects.create(name='Provider 1', slug='provider-1')
+
+        site.tags.set(tags[0])
+        provider.tags.set(tags[1])
+
     def test_name(self):
     def test_name(self):
         params = {'name': ['Tag 1', 'Tag 2']}
         params = {'name': ['Tag 1', 'Tag 2']}
         self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
         self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
@@ -549,6 +557,14 @@ class TagTestCase(TestCase, ChangeLoggedFilterSetTests):
         params = {'color': ['ff0000', '00ff00']}
         params = {'color': ['ff0000', '00ff00']}
         self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
         self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
 
 
+    def test_content_type(self):
+        params = {'content_type': ['dcim.site', 'circuits.provider']}
+        self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
+        site_ct = ContentType.objects.get_for_model(Site).pk
+        provider_ct = ContentType.objects.get_for_model(Provider).pk
+        params = {'content_type_id': [site_ct, provider_ct]}
+        self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
+
 
 
 class ObjectChangeTestCase(TestCase, BaseFilterSetTests):
 class ObjectChangeTestCase(TestCase, BaseFilterSetTests):
     queryset = ObjectChange.objects.all()
     queryset = ObjectChange.objects.all()