Explorar o código

Fixes #21538: Fix annotated count for contacts assigned to multiple contact groups (#21919)

Jeremy Stretch hai 1 mes
pai
achega
1af320e0a9

+ 1 - 7
netbox/tenancy/api/views.py

@@ -42,13 +42,7 @@ class TenantViewSet(NetBoxModelViewSet):
 #
 
 class ContactGroupViewSet(MPTTLockedMixin, NetBoxModelViewSet):
-    queryset = ContactGroup.objects.add_related_count(
-        ContactGroup.objects.all(),
-        Contact,
-        'groups',
-        'contact_count',
-        cumulative=True
-    )
+    queryset = ContactGroup.objects.annotate_contacts()
     serializer_class = serializers.ContactGroupSerializer
     filterset_class = filtersets.ContactGroupFilterSet
 

+ 26 - 0
netbox/tenancy/models/contacts.py

@@ -1,12 +1,14 @@
 from django.contrib.contenttypes.fields import GenericForeignKey
 from django.core.exceptions import ValidationError
 from django.db import models
+from django.db.models.expressions import RawSQL
 from django.urls import reverse
 from django.utils.translation import gettext_lazy as _
 
 from netbox.models import ChangeLoggedModel, NestedGroupModel, OrganizationalModel, PrimaryModel
 from netbox.models.features import CustomFieldsMixin, ExportTemplatesMixin, TagsMixin, has_feature
 from tenancy.choices import *
+from utilities.mptt import TreeManager
 
 __all__ = (
     'Contact',
@@ -16,10 +18,34 @@ __all__ = (
 )
 
 
+class ContactGroupManager(TreeManager):
+
+    def annotate_contacts(self):
+        """
+        Annotate the total number of Contacts belonging to each ContactGroup.
+
+        This returns both direct children and children of child groups. Raw SQL is used here to avoid double-counting
+        contacts which are assigned to multiple child groups of the parent.
+        """
+        return self.annotate(
+            contact_count=RawSQL(
+                "SELECT COUNT(DISTINCT m2m.contact_id)"
+                " FROM tenancy_contact_groups m2m"
+                " INNER JOIN tenancy_contactgroup cg ON m2m.contactgroup_id = cg.id"
+                " WHERE cg.tree_id = tenancy_contactgroup.tree_id"
+                " AND cg.lft >= tenancy_contactgroup.lft"
+                " AND cg.lft <= tenancy_contactgroup.rght",
+                ()
+            )
+        )
+
+
 class ContactGroup(NestedGroupModel):
     """
     An arbitrary collection of Contacts.
     """
+    objects = ContactGroupManager()
+
     class Meta:
         ordering = ['name']
         # Empty tuple triggers Django migration detection for MPTT indexes

+ 72 - 0
netbox/tenancy/tests/test_models.py

@@ -0,0 +1,72 @@
+from django.test import TestCase
+
+from tenancy.models import Contact, ContactGroup
+
+
+class ContactGroupTestCase(TestCase):
+
+    @classmethod
+    def setUpTestData(cls):
+        # Create a tree of contact groups:
+        #  - Group A
+        #    - Group A1
+        #    - Group A2
+        #  - Group B
+        cls.group_a = ContactGroup.objects.create(name='Group A', slug='group-a')
+        cls.group_a1 = ContactGroup.objects.create(name='Group A1', slug='group-a1', parent=cls.group_a)
+        cls.group_a2 = ContactGroup.objects.create(name='Group A2', slug='group-a2', parent=cls.group_a)
+        cls.group_b = ContactGroup.objects.create(name='Group B', slug='group-b')
+
+        # Create contacts
+        cls.contact1 = Contact.objects.create(name='Contact 1')
+        cls.contact2 = Contact.objects.create(name='Contact 2')
+        cls.contact3 = Contact.objects.create(name='Contact 3')
+        cls.contact4 = Contact.objects.create(name='Contact 4')
+
+    def test_annotate_contacts_direct(self):
+        """Contacts assigned directly to a group should be counted."""
+        self.contact1.groups.set([self.group_a])
+        self.contact2.groups.set([self.group_a])
+
+        queryset = ContactGroup.objects.annotate_contacts()
+        self.assertEqual(queryset.get(pk=self.group_a.pk).contact_count, 2)
+
+    def test_annotate_contacts_cumulative(self):
+        """Contacts assigned to child groups should be included in the parent's count."""
+        self.contact1.groups.set([self.group_a1])
+        self.contact2.groups.set([self.group_a2])
+
+        queryset = ContactGroup.objects.annotate_contacts()
+        self.assertEqual(queryset.get(pk=self.group_a.pk).contact_count, 2)
+        self.assertEqual(queryset.get(pk=self.group_a1.pk).contact_count, 1)
+        self.assertEqual(queryset.get(pk=self.group_a2.pk).contact_count, 1)
+
+    def test_annotate_contacts_no_double_counting(self):
+        """A contact assigned to multiple child groups must be counted only once for the parent."""
+        self.contact1.groups.set([self.group_a1, self.group_a2])
+
+        queryset = ContactGroup.objects.annotate_contacts()
+        self.assertEqual(queryset.get(pk=self.group_a.pk).contact_count, 1)
+
+    def test_annotate_contacts_mixed(self):
+        """Test a mix of direct and inherited contacts with overlap."""
+        self.contact1.groups.set([self.group_a])
+        self.contact2.groups.set([self.group_a1])
+        self.contact3.groups.set([self.group_a1, self.group_a2])
+        self.contact4.groups.set([self.group_b])
+
+        queryset = ContactGroup.objects.annotate_contacts()
+        # Group A: contact1 (direct) + contact2 (via A1) + contact3 (via A1 & A2) = 3
+        self.assertEqual(queryset.get(pk=self.group_a.pk).contact_count, 3)
+        # Group A1: contact2 + contact3 = 2
+        self.assertEqual(queryset.get(pk=self.group_a1.pk).contact_count, 2)
+        # Group A2: contact3 = 1
+        self.assertEqual(queryset.get(pk=self.group_a2.pk).contact_count, 1)
+        # Group B: contact4 = 1
+        self.assertEqual(queryset.get(pk=self.group_b.pk).contact_count, 1)
+
+    def test_annotate_contacts_empty(self):
+        """Groups with no contacts should return a count of zero."""
+        queryset = ContactGroup.objects.annotate_contacts()
+        self.assertEqual(queryset.get(pk=self.group_a.pk).contact_count, 0)
+        self.assertEqual(queryset.get(pk=self.group_b.pk).contact_count, 0)

+ 4 - 22
netbox/tenancy/views.py

@@ -205,13 +205,7 @@ class TenantBulkDeleteView(generic.BulkDeleteView):
 
 @register_model_view(ContactGroup, 'list', path='', detail=False)
 class ContactGroupListView(generic.ObjectListView):
-    queryset = ContactGroup.objects.add_related_count(
-        ContactGroup.objects.all(),
-        Contact,
-        'groups',
-        'contact_count',
-        cumulative=True
-    )
+    queryset = ContactGroup.objects.annotate_contacts()
     filterset = filtersets.ContactGroupFilterSet
     filterset_form = forms.ContactGroupFilterForm
     table = tables.ContactGroupTable
@@ -254,7 +248,7 @@ class ContactGroupView(GetRelatedModelsMixin, generic.ObjectView):
                 request,
                 groups,
                 extra=(
-                    (Contact.objects.restrict(request.user, 'view').filter(groups__in=groups), 'group_id'),
+                    (Contact.objects.restrict(request.user, 'view').filter(groups__in=groups).distinct(), 'group_id'),
                 ),
             ),
         }
@@ -280,13 +274,7 @@ class ContactGroupBulkImportView(generic.BulkImportView):
 
 @register_model_view(ContactGroup, 'bulk_edit', path='edit', detail=False)
 class ContactGroupBulkEditView(generic.BulkEditView):
-    queryset = ContactGroup.objects.add_related_count(
-        ContactGroup.objects.all(),
-        Contact,
-        'groups',
-        'contact_count',
-        cumulative=True
-    )
+    queryset = ContactGroup.objects.annotate_contacts()
     filterset = filtersets.ContactGroupFilterSet
     table = tables.ContactGroupTable
     form = forms.ContactGroupBulkEditForm
@@ -300,13 +288,7 @@ class ContactGroupBulkRenameView(generic.BulkRenameView):
 
 @register_model_view(ContactGroup, 'bulk_delete', path='delete', detail=False)
 class ContactGroupBulkDeleteView(generic.BulkDeleteView):
-    queryset = ContactGroup.objects.add_related_count(
-        ContactGroup.objects.all(),
-        Contact,
-        'groups',
-        'contact_count',
-        cumulative=True
-    )
+    queryset = ContactGroup.objects.annotate_contacts()
     filterset = filtersets.ContactGroupFilterSet
     table = tables.ContactGroupTable